from __future__ import annotations

import asyncio
import json
import os
import ssl
import time
import uuid
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import aiohttp
import websockets
from loguru import logger

from analysis.safe_btc5.clients import GAMMA_BASE
from analysis.safe_btc5.types import parse_dt

from .activity_trades import TargetMarket, select_target_markets
from .store import LplOrderbookStore, NoopLplOrderbookStore, dumps


# Polymarket CLOB REST base URL (used for `/book` order-book snapshots).
CLOB_BASE = "https://clob.polymarket.com"
# Polymarket CLOB market-data websocket (delivers book / price_change events).
CLOB_WS = "wss://ws-subscriptions-clob.polymarket.com/ws/market"


def _json_list(raw: Any) -> list[Any]:
    if raw is None:
        return []
    if isinstance(raw, list):
        return raw
    if isinstance(raw, str):
        try:
            value = json.loads(raw)
        except json.JSONDecodeError:
            return []
        return value if isinstance(value, list) else []
    return []


def parse_ws_ts(value: Any) -> datetime | None:
    """Parse a websocket timestamp into an aware UTC datetime.

    Strings are first tried as ISO datetimes via ``parse_dt``; failing that
    they are reinterpreted as numeric epochs. Numeric values above 1e12 are
    assumed to be milliseconds since the epoch. Returns ``None`` for empty
    or unparseable input.
    """
    if value in (None, ""):
        return None
    if isinstance(value, str):
        iso = parse_dt(value)
        if iso is not None:
            return iso
        try:
            value = float(value)
        except ValueError:
            return None
    if not isinstance(value, (int, float)):
        return None
    epoch = float(value)
    if epoch > 1e12:
        # Heuristic: very large values are epoch milliseconds, not seconds.
        epoch /= 1000.0
    try:
        return datetime.fromtimestamp(epoch, tz=timezone.utc)
    except (OSError, OverflowError, ValueError):
        return None


class OrderBook:
    """In-memory two-sided order book; levels keyed by price string."""

    def __init__(self) -> None:
        self.bids: dict[str, float] = {}
        self.asks: dict[str, float] = {}

    @staticmethod
    def _build_side(orders: list[dict[str, Any]]) -> dict[str, float]:
        """Build one side's level map from raw order dicts, skipping zero sizes."""
        levels: dict[str, float] = {}
        for order in orders:
            qty = float(order.get("size", 0))
            if qty > 0:
                levels[str(order["price"])] = qty
        return levels

    def apply_snapshot(self, data: dict[str, Any]) -> None:
        """Replace both sides with the levels found in *data*."""
        self.bids = self._build_side(data.get("bids", []))
        self.asks = self._build_side(data.get("asks", []))

    def apply_delta(self, changes: list[dict[str, Any]]) -> None:
        """Apply incremental level updates; a zero size deletes the level."""
        for entry in changes:
            price = str(entry.get("price", ""))
            if not price:
                continue
            is_buy = str(entry.get("side", "")).upper() == "BUY"
            side = self.bids if is_buy else self.asks
            qty = float(entry.get("size") or 0)
            if qty == 0:
                side.pop(price, None)
            else:
                side[price] = qty

    def best_bid(self) -> tuple[float | None, float | None]:
        """Return ``(price, size)`` of the highest bid, or ``(None, None)``."""
        if not self.bids:
            return None, None
        top = max(self.bids, key=float)
        return float(top), self.bids[top]

    def best_ask(self) -> tuple[float | None, float | None]:
        """Return ``(price, size)`` of the lowest ask, or ``(None, None)``."""
        if not self.asks:
            return None, None
        top = min(self.asks, key=float)
        return float(top), self.asks[top]

    def bid_levels(self) -> list[dict[str, float]]:
        """All bid levels as dicts, best (highest price) first."""
        ordered = sorted(self.bids.items(), key=lambda kv: float(kv[0]), reverse=True)
        return [{"price": float(price), "size": size} for price, size in ordered]

    def ask_levels(self) -> list[dict[str, float]]:
        """All ask levels as dicts, best (lowest price) first."""
        ordered = sorted(self.asks.items(), key=lambda kv: float(kv[0]))
        return [{"price": float(price), "size": size} for price, size in ordered]


@dataclass
class LplMarket:
    """Static description of the market (or synthetic event market) recorded."""

    slug: str
    condition_id: str
    question: str
    outcomes: list[str]
    token_ids: list[str]
    target_markets: list[TargetMarket] | None = None

    @property
    def token_to_outcome(self) -> dict[str, str]:
        """Map each CLOB token id to its human-readable outcome label.

        With target markets the label is ``"<kind>:<outcome>"`` (or just the
        kind when the outcome is missing); otherwise tokens map positionally
        onto ``outcomes`` with ``""`` past the end.
        """
        if not self.target_markets:
            plain: dict[str, str] = {}
            for idx, token_id in enumerate(self.token_ids):
                plain[token_id] = self.outcomes[idx] if idx < len(self.outcomes) else ""
            return plain
        labelled: dict[str, str] = {}
        for target in self.target_markets:
            for idx, token_id in enumerate(target.token_ids):
                outcome = target.outcomes[idx] if idx < len(target.outcomes) else ""
                labelled[token_id] = f"{target.kind}:{outcome}" if outcome else target.kind
        return labelled


class JsonlWriter:
    """Append-only JSONL writer with batched, time-bounded flushing."""

    def __init__(
        self,
        path: Path,
        *,
        flush_every: int = 100,
        flush_interval_seconds: float = 0.5,
    ) -> None:
        """Open *path* for appending, creating parent directories as needed.

        Flushes happen after ``flush_every`` buffered rows or once
        ``flush_interval_seconds`` have elapsed since the last flush,
        whichever comes first.
        """
        self.path = path
        self.flush_every = max(1, flush_every)
        self.flush_interval_seconds = max(0.0, flush_interval_seconds)
        self._pending = 0
        self._last_flush_monotonic = time.monotonic()
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._file = path.open("a", encoding="utf-8")

    def write(self, row: dict[str, Any]) -> None:
        """Serialize *row* as one JSONL line and flush when thresholds hit."""
        self._file.write(dumps(row) + "\n")
        self._pending += 1
        elapsed = time.monotonic() - self._last_flush_monotonic
        if self._pending >= self.flush_every or elapsed >= self.flush_interval_seconds:
            self.flush()

    def flush(self) -> None:
        """Push buffered lines to the OS and reset batching counters."""
        self._file.flush()
        self._last_flush_monotonic = time.monotonic()
        self._pending = 0

    def close(self) -> None:
        """Flush any remaining rows, then close the underlying file."""
        self.flush()
        self._file.close()


class LplOrderbookRecorder:
    """Records CLOB order-book activity for one Polymarket market.

    Seeds in-memory books from the REST ``/book`` endpoint, then streams the
    CLOB market websocket, persisting every raw event and periodic book
    snapshots as JSONL under ``output_root/slug`` and, when a DSN is
    configured, into an ``LplOrderbookStore``. A throttled status JSON file
    (``recording_status.json``) is refreshed for external monitoring.
    """

    def __init__(
        self,
        *,
        slug: str,
        output_root: Path = Path("data/lpl"),
        snapshot_interval_seconds: float = 1.0,
        reconnect_delay_seconds: float = 3.0,
        insecure_ssl: bool = False,
        db_dsn: str | None = None,
        status_interval_seconds: float = 5.0,
        jsonl_flush_every: int = 100,
        jsonl_flush_interval_seconds: float = 0.5,
    ) -> None:
        self.slug = slug
        self.output_root = output_root
        self.snapshot_interval_seconds = snapshot_interval_seconds
        self.reconnect_delay_seconds = reconnect_delay_seconds
        self.insecure_ssl = insecure_ssl
        self.db_dsn = db_dsn
        self.status_interval_seconds = status_interval_seconds
        self.jsonl_flush_every = jsonl_flush_every
        self.jsonl_flush_interval_seconds = jsonl_flush_interval_seconds
        # Unique id for this recording run; stamped onto every row written.
        self.session_id = str(uuid.uuid4())
        self.stop_event = asyncio.Event()
        # Counters are session-wide: they survive websocket reconnects.
        self.message_index = 0
        self.connection_id = 0
        self.reconnect_count = 0
        self.events_written = 0
        self.snapshots_written = 0
        self.started_at_wall: datetime | None = None
        self.last_received_at_wall: datetime | None = None
        self.last_exchange_ts: datetime | None = None
        self.last_event_type = ""
        self.last_status_write_monotonic = 0.0
        # One live in-memory book per CLOB token id.
        self.books: dict[str, OrderBook] = {}
        self.market: LplMarket | None = None
        self.events_jsonl: JsonlWriter | None = None
        self.snapshots_jsonl: JsonlWriter | None = None
        self.status_path: Path | None = None
        self.output_dir: Path | None = None

    def stop(self) -> None:
        """Request a graceful shutdown of `run()` and its background tasks."""
        self.stop_event.set()

    async def run(self) -> None:
        """Record until stopped or a fatal error occurs.

        Creates the output directory and JSONL writers, resolves the market,
        writes ``recording_meta.json``, opens the (optional) DB store, seeds
        the books over REST and then runs the websocket loop alongside the
        periodic snapshot and status tasks. All cleanup — cancelling the
        background tasks, finalizing the store session, closing the writers
        and writing a final status — happens in the ``finally`` block.
        """
        output_dir = self.output_root / self.slug
        output_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = output_dir
        self.status_path = output_dir / "recording_status.json"
        self.started_at_wall = datetime.now(timezone.utc)
        self.events_jsonl = JsonlWriter(
            output_dir / "orderbook_events.jsonl",
            flush_every=self.jsonl_flush_every,
            flush_interval_seconds=self.jsonl_flush_interval_seconds,
        )
        self.snapshots_jsonl = JsonlWriter(
            output_dir / "orderbook_snapshots.jsonl",
            flush_every=self.jsonl_flush_every,
            flush_interval_seconds=self.jsonl_flush_interval_seconds,
        )

        # insecure_ssl disables certificate verification for the REST client.
        connector = aiohttp.TCPConnector(ssl=False) if self.insecure_ssl else None
        async with aiohttp.ClientSession(connector=connector) as http:
            self.market = await self._fetch_market(http)
            self.books = {token_id: OrderBook() for token_id in self.market.token_ids}
            meta = {
                "session_id": self.session_id,
                "slug": self.market.slug,
                "condition_id": self.market.condition_id,
                "question": self.market.question,
                "outcomes": self.market.outcomes,
                "token_ids": self.market.token_ids,
                "target_markets": [
                    target.__dict__ for target in self.market.target_markets or []
                ],
                "started_at": datetime.now(timezone.utc).isoformat(),
            }
            (output_dir / "recording_meta.json").write_text(
                json.dumps(meta, ensure_ascii=False, indent=2, sort_keys=True) + "\n",
                encoding="utf-8",
            )

            # Fall back to a no-op store unless a DSN is configured either
            # explicitly or via the DATABASE_URL environment variable.
            store_factory = (
                LplOrderbookStore(self.db_dsn)
                if self.db_dsn or os.getenv("DATABASE_URL")
                else NoopLplOrderbookStore()
            )
            async with store_factory as store:
                await store.create_session(
                    session_id=self.session_id,
                    slug=self.market.slug,
                    condition_id=self.market.condition_id,
                    question=self.market.question,
                    outcomes=self.market.outcomes,
                    token_ids=self.market.token_ids,
                    output_dir=str(output_dir),
                    started_at=datetime.now(timezone.utc),
                )
                status = "stopped_by_user"
                notes = ""
                snapshot_task = asyncio.create_task(
                    self._periodic_snapshots(store),
                    name=f"lpl_snapshot_{self.slug}",
                )
                status_task = asyncio.create_task(
                    self._periodic_status("running"),
                    name=f"lpl_status_{self.slug}",
                )
                try:
                    await self._load_initial_books(http, store)
                    self._write_status("running", force=True)
                    await self._ws_loop(store, http)
                except asyncio.CancelledError:
                    status = "cancelled"
                    raise
                except Exception as exc:
                    status = "error"
                    notes = str(exc)
                    raise
                finally:
                    # Stop background tasks first so no writes race the close.
                    self.stop_event.set()
                    snapshot_task.cancel()
                    status_task.cancel()
                    with suppress(asyncio.CancelledError):
                        await snapshot_task
                    with suppress(asyncio.CancelledError):
                        await status_task
                    await store.finish_session(
                        session_id=self.session_id,
                        ended_at=datetime.now(timezone.utc),
                        status=status,
                        notes=notes,
                    )
                    if self.events_jsonl:
                        self.events_jsonl.close()
                    if self.snapshots_jsonl:
                        self.snapshots_jsonl.close()
                    self._write_status(status, force=True, notes=notes)

    async def _fetch_market(self, http: aiohttp.ClientSession) -> LplMarket:
        """Resolve ``self.slug`` via the Gamma API into an ``LplMarket``.

        Raises:
            RuntimeError: if the market exposes no CLOB token ids.
            aiohttp.ClientResponseError: on a non-2xx Gamma response.
        """
        async with http.get(
            f"{GAMMA_BASE}/markets/slug/{self.slug}",
            timeout=aiohttp.ClientTimeout(total=20),
        ) as resp:
            resp.raise_for_status()
            raw = await resp.json()
        # Gamma may return these fields either as lists or JSON-encoded strings.
        outcomes = [str(item) for item in _json_list(raw.get("outcomes"))]
        token_ids = [str(item) for item in _json_list(raw.get("clobTokenIds"))]
        if not token_ids:
            raise RuntimeError(f"market has no clobTokenIds: {self.slug}")
        return LplMarket(
            slug=str(raw.get("slug") or self.slug),
            condition_id=str(raw.get("conditionId") or ""),
            question=str(raw.get("question") or ""),
            outcomes=outcomes,
            token_ids=token_ids,
        )

    async def _load_initial_books(
        self,
        http: aiohttp.ClientSession,
        store: LplOrderbookStore,
    ) -> None:
        """Seed every in-memory book from the REST ``/book`` endpoint.

        Records each REST response as a ``rest_book`` event and emits an
        ``initial_rest`` snapshot per token.
        """
        assert self.market is not None
        for token_id in self.market.token_ids:
            async with http.get(
                f"{CLOB_BASE}/book",
                params={"token_id": token_id},
                timeout=aiohttp.ClientTimeout(total=10),
            ) as resp:
                resp.raise_for_status()
                raw = await resp.json()
            self.books[token_id].apply_snapshot(raw)
            await self._record_event(
                store,
                raw=raw,
                event_type="rest_book",
                asset_id=token_id,
                exchange_ts=None,
            )
            await self._record_snapshot(store, token_id, "initial_rest")

    async def _ws_loop(
        self,
        store: LplOrderbookStore,
        http: aiohttp.ClientSession,
    ) -> None:
        """Stream the CLOB market websocket until ``stop_event`` is set.

        Each (re)connection records a ``reconnect`` marker event; after an
        actual reconnect the books are re-seeded over REST because deltas
        may have been missed. Connection errors trigger a delayed retry.
        """
        assert self.market is not None
        sub_msg = json.dumps(
            {
                "assets_ids": self.market.token_ids,
                "type": "market",
                "custom_feature_enabled": True,
            }
        )
        while not self.stop_event.is_set():
            self.connection_id += 1
            # connection_id 1 is the first attempt; anything later is a reconnect.
            is_reconnect = self.connection_id > 1
            if is_reconnect:
                self.reconnect_count += 1
            try:
                await self._record_event(
                    store,
                    raw={
                        "event_type": "reconnect",
                        "connection_id": self.connection_id,
                        "is_reconnect": is_reconnect,
                        "last_received_at_wall": (
                            self.last_received_at_wall.isoformat()
                            if self.last_received_at_wall
                            else ""
                        ),
                        "reconnect_count": self.reconnect_count,
                    },
                    event_type="reconnect",
                    asset_id="",
                    exchange_ts=None,
                )
                if is_reconnect:
                    await self._reload_all_books(http, store, "rest_book_after_reconnect")
                # ping_interval=None disables protocol pings; an app-level
                # "PING" task (below) keeps the connection alive instead.
                connect_kwargs: dict[str, Any] = {
                    "ping_interval": None,
                    "open_timeout": 10,
                }
                if self.insecure_ssl:
                    # Private but long-standing ssl helper; skips cert checks.
                    connect_kwargs["ssl"] = ssl._create_unverified_context()
                async with websockets.connect(CLOB_WS, **connect_kwargs) as ws:
                    await ws.send(sub_msg)
                    ping_task = asyncio.create_task(self._ping(ws))
                    try:
                        while not self.stop_event.is_set():
                            # Short timeout so stop_event is polled regularly.
                            try:
                                raw_msg = await asyncio.wait_for(ws.recv(), timeout=1.0)
                            except asyncio.TimeoutError:
                                continue
                            if isinstance(raw_msg, bytes):
                                raw_msg = raw_msg.decode()
                            # Skip non-JSON frames (e.g. PONG-style replies).
                            if not raw_msg or not raw_msg.strip().startswith("{"):
                                continue
                            try:
                                msg = json.loads(raw_msg)
                            except json.JSONDecodeError:
                                continue
                            await self._handle_ws_message(store, http, msg)
                    finally:
                        ping_task.cancel()
                        with suppress(asyncio.CancelledError):
                            await ping_task
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                if not self.stop_event.is_set():
                    logger.warning(f"LPL CLOB WS 断线: {exc}; 重新连接")
                    await asyncio.sleep(self.reconnect_delay_seconds)

    async def _handle_ws_message(
        self,
        store: LplOrderbookStore,
        http: aiohttp.ClientSession,
        msg: dict[str, Any],
    ) -> None:
        """Record one websocket message and update in-memory books.

        ``book`` replaces a whole book; ``price_change`` applies grouped
        per-token deltas; ``tick_size_change`` forces a REST reload of the
        affected book. Unknown event types are recorded but otherwise ignored.
        """
        assert self.market is not None
        event_type = str(msg.get("event_type") or "unknown")
        exchange_ts = parse_ws_ts(msg.get("timestamp"))
        asset_id = str(msg.get("asset_id") or "")
        await self._record_event(
            store,
            raw=msg,
            event_type=event_type,
            asset_id=asset_id,
            exchange_ts=exchange_ts,
        )

        if event_type == "book" and asset_id in self.books:
            self.books[asset_id].apply_snapshot(msg)
        elif event_type == "price_change":
            # A single message may carry changes for several tokens; group
            # them so each book gets one apply_delta call.
            grouped: dict[str, list[dict[str, Any]]] = {}
            for change in msg.get("price_changes", []):
                token_id = str(change.get("asset_id") or "")
                if token_id in self.books:
                    grouped.setdefault(token_id, []).append(change)
            for token_id, changes in grouped.items():
                self.books[token_id].apply_delta(changes)
        elif event_type == "tick_size_change" and asset_id in self.books:
            await self._reload_book(http, store, asset_id, "rest_book_after_tick_size_change")

    async def _reload_all_books(
        self,
        http: aiohttp.ClientSession,
        store: LplOrderbookStore,
        event_type: str,
    ) -> None:
        """REST-reload every tracked book, tagging events with *event_type*."""
        for token_id in list(self.books):
            await self._reload_book(http, store, token_id, event_type)

    async def _reload_book(
        self,
        http: aiohttp.ClientSession,
        store: LplOrderbookStore,
        token_id: str,
        event_type: str,
    ) -> None:
        """Replace one book from REST and record both the event and a snapshot."""
        async with http.get(
            f"{CLOB_BASE}/book",
            params={"token_id": token_id},
            timeout=aiohttp.ClientTimeout(total=10),
        ) as resp:
            resp.raise_for_status()
            raw = await resp.json()
        self.books[token_id].apply_snapshot(raw)
        await self._record_event(
            store,
            raw=raw,
            event_type=event_type,
            asset_id=token_id,
            exchange_ts=None,
        )
        await self._record_snapshot(store, token_id, event_type)

    async def _periodic_snapshots(self, store: LplOrderbookStore) -> None:
        """Emit a ``periodic`` snapshot of every book each interval until stopped."""
        while not self.stop_event.is_set():
            await asyncio.sleep(self.snapshot_interval_seconds)
            for token_id in list(self.books):
                await self._record_snapshot(store, token_id, "periodic")

    async def _periodic_status(self, status: str) -> None:
        """Refresh the status file on a fixed cadence until stopped."""
        while not self.stop_event.is_set():
            self._write_status(status)
            await asyncio.sleep(self.status_interval_seconds)

    async def _record_event(
        self,
        store: LplOrderbookStore,
        *,
        raw: dict[str, Any],
        event_type: str,
        asset_id: str,
        exchange_ts: datetime | None,
    ) -> None:
        """Persist one raw event to JSONL and the store, updating counters.

        Also refreshes the last-seen bookkeeping used by the status file and
        triggers a (throttled) status write.
        """
        assert self.market is not None
        self.message_index += 1
        received_at_wall = datetime.now(timezone.utc)
        # Monotonic nanoseconds allow precise inter-event timing offline.
        received_at_monotonic_ns = time.monotonic_ns()
        self.last_received_at_wall = received_at_wall
        self.last_exchange_ts = exchange_ts
        self.last_event_type = event_type
        outcome = self.market.token_to_outcome.get(asset_id, "")
        row = {
            "session_id": self.session_id,
            "slug": self.market.slug,
            "condition_id": self.market.condition_id,
            "asset_id": asset_id,
            "outcome": outcome,
            "event_type": event_type,
            "message_index": self.message_index,
            "connection_id": self.connection_id,
            "exchange_ts": exchange_ts.isoformat() if exchange_ts else "",
            "received_at_wall": received_at_wall.isoformat(),
            "received_at_monotonic_ns": received_at_monotonic_ns,
            "raw_json": raw,
        }
        assert self.events_jsonl is not None
        self.events_jsonl.write(row)
        self.events_written += 1
        self._write_status("running")
        await store.insert_event(
            session_id=self.session_id,
            slug=self.market.slug,
            condition_id=self.market.condition_id,
            asset_id=asset_id,
            outcome=outcome,
            event_type=event_type,
            message_index=self.message_index,
            connection_id=self.connection_id,
            exchange_ts=exchange_ts,
            received_at_wall=received_at_wall,
            received_at_monotonic_ns=received_at_monotonic_ns,
            raw_json=raw,
        )

    async def _record_snapshot(
        self,
        store: LplOrderbookStore,
        token_id: str,
        source: str,
    ) -> None:
        """Persist the current state of one book (top + full depth) to JSONL/store.

        ``source`` labels why the snapshot was taken (e.g. ``initial_rest``,
        ``periodic``, ``rest_book_after_reconnect``).
        """
        assert self.market is not None
        book = self.books[token_id]
        best_bid, _bid_size = book.best_bid()
        best_ask, _ask_size = book.best_ask()
        snapshot_at_wall = datetime.now(timezone.utc)
        outcome = self.market.token_to_outcome.get(token_id, "")
        bid_levels = book.bid_levels()
        ask_levels = book.ask_levels()
        row = {
            "session_id": self.session_id,
            "slug": self.market.slug,
            "asset_id": token_id,
            "outcome": outcome,
            "snapshot_at_wall": snapshot_at_wall.isoformat(),
            "message_index": self.message_index,
            "best_bid": best_bid,
            "best_ask": best_ask,
            "bid_levels": bid_levels,
            "ask_levels": ask_levels,
            "source": source,
        }
        assert self.snapshots_jsonl is not None
        self.snapshots_jsonl.write(row)
        self.snapshots_written += 1
        await store.insert_snapshot(
            session_id=self.session_id,
            slug=self.market.slug,
            asset_id=token_id,
            outcome=outcome,
            snapshot_at_wall=snapshot_at_wall,
            message_index=self.message_index,
            best_bid=best_bid,
            best_ask=best_ask,
            bid_levels=bid_levels,
            ask_levels=ask_levels,
            source=source,
        )

    def _write_status(self, status: str, *, force: bool = False, notes: str = "") -> None:
        """Write ``recording_status.json``, throttled unless *force* is set.

        Throttling reuses ``status_interval_seconds`` so frequent event
        traffic does not rewrite the file on every message.
        """
        if self.status_path is None:
            return
        now_monotonic = time.monotonic()
        if (
            not force
            and now_monotonic - self.last_status_write_monotonic < self.status_interval_seconds
        ):
            return
        self.last_status_write_monotonic = now_monotonic
        now = datetime.now(timezone.utc)
        last_age = (
            round((now - self.last_received_at_wall).total_seconds(), 3)
            if self.last_received_at_wall is not None
            else None
        )
        payload = {
            "session_id": self.session_id,
            "slug": self.market.slug if self.market else self.slug,
            "status": status,
            "notes": notes,
            "updated_at": now.isoformat(),
            "started_at": self.started_at_wall.isoformat() if self.started_at_wall else "",
            "output_dir": str(self.output_dir or ""),
            "snapshot_interval_seconds": self.snapshot_interval_seconds,
            "status_interval_seconds": self.status_interval_seconds,
            "jsonl_flush_every": self.jsonl_flush_every,
            "jsonl_flush_interval_seconds": self.jsonl_flush_interval_seconds,
            "connection_id": self.connection_id,
            "reconnect_count": self.reconnect_count,
            "message_index": self.message_index,
            "events_written": self.events_written,
            "snapshots_written": self.snapshots_written,
            "last_event_type": self.last_event_type,
            "last_exchange_ts": self.last_exchange_ts.isoformat() if self.last_exchange_ts else "",
            "last_received_at_wall": (
                self.last_received_at_wall.isoformat() if self.last_received_at_wall else ""
            ),
            "last_message_age_seconds": last_age,
            "token_count": len(self.books),
        }
        self.status_path.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )

    @staticmethod
    async def _ping(ws: Any) -> None:
        """Send an application-level "PING" text frame every 10 s.

        Best-effort keepalive: any exception (including a closed socket)
        simply ends the task, which the caller cancels on teardown anyway.
        """
        try:
            while True:
                await asyncio.sleep(10)
                await ws.send("PING")
        except Exception:
            pass


class LplEventOrderbookRecorder(LplOrderbookRecorder):
    """Recorder variant that follows selected markets of a Gamma *event*.

    Instead of resolving a single market slug, it resolves an event slug,
    selects the requested target markets via ``select_target_markets`` and
    records the union of their CLOB token ids in one session.
    """

    def __init__(
        self,
        *,
        event_slug: str,
        markets: set[str] | None = None,
        output_root: Path = Path("data/lpl"),
        snapshot_interval_seconds: float = 0.5,
        reconnect_delay_seconds: float = 3.0,
        insecure_ssl: bool = False,
        db_dsn: str | None = None,
        status_interval_seconds: float = 5.0,
        jsonl_flush_every: int = 100,
        jsonl_flush_interval_seconds: float = 0.5,
    ) -> None:
        """Create an event recorder.

        Args:
            event_slug: Gamma event slug to resolve (also used as the session
                slug / output directory name).
            markets: Optional subset of market identifiers to record; ``None``
                lets ``select_target_markets`` apply its defaults.

        Remaining parameters are forwarded to ``LplOrderbookRecorder``.
        """
        super().__init__(
            slug=event_slug,
            output_root=output_root,
            snapshot_interval_seconds=snapshot_interval_seconds,
            reconnect_delay_seconds=reconnect_delay_seconds,
            insecure_ssl=insecure_ssl,
            db_dsn=db_dsn,
            status_interval_seconds=status_interval_seconds,
            jsonl_flush_every=jsonl_flush_every,
            jsonl_flush_interval_seconds=jsonl_flush_interval_seconds,
        )
        self.event_slug = event_slug
        self.requested_markets = markets

    async def _fetch_market(self, http: aiohttp.ClientSession) -> LplMarket:
        """Resolve the event into a synthetic ``LplMarket``.

        Fetches the event from Gamma, selects the target markets and flattens
        their token ids/outcomes (labelled ``"<kind>:<outcome>"``) into one
        market description with an empty condition id.

        Raises:
            RuntimeError: if no target markets match the request, or if the
                selected targets share a token id.
            aiohttp.ClientResponseError: on a non-2xx Gamma response.
        """
        async with http.get(
            f"{GAMMA_BASE}/events/slug/{self.event_slug}",
            timeout=aiohttp.ClientTimeout(total=20),
        ) as resp:
            resp.raise_for_status()
            raw = await resp.json()

        targets = select_target_markets(
            self.event_slug,
            raw.get("markets") or [],
            self.requested_markets,
        )
        if not targets:
            # Fix: requested_markets may be None (the default); sorted(None)
            # raised TypeError and masked the intended error message.
            requested = sorted(self.requested_markets) if self.requested_markets else []
            raise RuntimeError(
                f"event has no requested target markets: "
                f"event_slug={self.event_slug} markets={requested}"
            )

        token_ids: list[str] = []
        outcomes: list[str] = []
        for target in targets:
            for idx, token_id in enumerate(target.token_ids):
                token_ids.append(token_id)
                # Fall back to the token id itself when no outcome is listed.
                outcome = target.outcomes[idx] if idx < len(target.outcomes) else token_id
                outcomes.append(f"{target.kind}:{outcome}")

        if len(set(token_ids)) != len(token_ids):
            raise RuntimeError(f"duplicate token ids in event targets: {self.event_slug}")

        return LplMarket(
            slug=str(raw.get("slug") or self.event_slug),
            condition_id="",
            question=str(raw.get("title") or raw.get("ticker") or self.event_slug),
            outcomes=outcomes,
            token_ids=token_ids,
            target_markets=targets,
        )
