from __future__ import annotations

import asyncio
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Iterable, Optional

import aiohttp
from aiohttp import ClientResponseError

from gamma_event_parse import extract_ref_prices

from .types import (
    BTC_PREFIX,
    SESSION_SEC,
    MarketInfo,
    floor_session_ts,
    parse_dt,
    slug_start_ts,
)


# Default service endpoints; PolymarketDataClient accepts per-instance overrides for each.
# Data API: serves the /trades and /activity endpoints queried below.
DATA_API_BASE = "https://data-api.polymarket.com"
# Gamma API: serves market metadata via /markets/slug/<slug>.
GAMMA_BASE = "https://gamma-api.polymarket.com"
# Polygon JSON-RPC endpoint used for eth_getTransactionReceipt lookups.
POLYGON_RPC = "https://polygon.drpc.org"


def _json_list(raw: Any) -> list[Any]:
    if raw is None:
        return []
    if isinstance(raw, list):
        return raw
    if isinstance(raw, str):
        try:
            value = json.loads(raw)
        except json.JSONDecodeError:
            return []
        return value if isinstance(value, list) else []
    return []


def _is_btc_5m_slug(slug: str) -> bool:
    return bool(slug and slug.startswith(BTC_PREFIX))


class PolymarketDataClient:
    """Async client for Polymarket's data API, its Gamma API and a Polygon RPC node.

    Every request goes through the ``aiohttp.ClientSession`` passed to the
    constructor; this class neither creates nor closes that session.
    """

    def __init__(
        self,
        http: aiohttp.ClientSession,
        data_api_base: str = DATA_API_BASE,
        gamma_base: str = GAMMA_BASE,
        rpc_url: str = POLYGON_RPC,
    ) -> None:
        """Store the shared HTTP session and the service base URLs.

        Args:
            http: Session used for every request.
            data_api_base: Base URL for the ``/trades`` and ``/activity`` endpoints.
            gamma_base: Base URL for the ``/markets/slug/<slug>`` endpoint.
            rpc_url: JSON-RPC endpoint used by :meth:`fetch_receipt`.
        """
        self.http = http
        # Strip trailing slashes so f"{base}/path" never produces "//path".
        self.data_api_base = data_api_base.rstrip("/")
        self.gamma_base = gamma_base.rstrip("/")
        self.rpc_url = rpc_url

    @staticmethod
    def _row_ts(row: dict[str, Any]) -> int:
        """Return ``row["timestamp"]`` as an int; a missing or null value counts as 0."""
        return int(row.get("timestamp") or 0)

    async def _get_json(self, url: str, params: Optional[dict[str, Any]] = None) -> Any:
        """GET *url* with query *params* and return the decoded JSON body.

        Raises:
            aiohttp.ClientResponseError: when the response status is not 2xx.
        """
        async with self.http.get(
            url,
            params=params,
            timeout=aiohttp.ClientTimeout(total=20),
        ) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def fetch_trades_page(
        self,
        safe: str,
        *,
        limit: int,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Fetch one unfiltered page of ``/trades`` for user *safe*.

        Args:
            safe: Value sent as the ``user`` query parameter.
            limit: Page size requested from the API.
            offset: Number of rows to skip.

        Returns:
            The decoded rows, or ``[]`` if the body is not a JSON array.
        """
        data = await self._get_json(
            f"{self.data_api_base}/trades",
            {
                "user": safe,
                "takerOnly": "false",
                "limit": limit,
                "offset": offset,
            },
        )
        return data if isinstance(data, list) else []

    async def fetch_recent_trades(
        self,
        safe: str,
        *,
        limit: int = 10_000,
        page_size: int = 500,
    ) -> list[dict[str, Any]]:
        """Fetch up to *limit* trades for *safe* (from offset 0), keeping BTC 5m markets only.

        A single request for all *limit* rows is tried first. If the API rejects
        it with an HTTP error, the rows are collected page by page instead,
        stopping on a 400 response, an empty or short page, or once *limit*
        rows have been gathered.

        *limit* caps the rows fetched before filtering, so the result may hold
        fewer than *limit* rows.

        Raises:
            aiohttp.ClientResponseError: if a paginated request fails with a
                status other than 400.
        """
        try:
            page = await self.fetch_trades_page(safe, limit=limit, offset=0)
            return [r for r in page if _is_btc_5m_slug(r.get("slug", ""))]
        except ClientResponseError:
            # Deliberate fallback: retry the same range in smaller pages below.
            pass

        rows: list[dict[str, Any]] = []
        offset = 0
        while len(rows) < limit:
            page_limit = min(page_size, limit - len(rows))
            try:
                page = await self.fetch_trades_page(safe, limit=page_limit, offset=offset)
            except ClientResponseError as exc:
                # A 400 is treated as "no more pages"; anything else propagates.
                if exc.status == 400:
                    break
                raise
            if not page:
                break
            rows.extend(page)
            if len(page) < page_limit:
                break
            offset += len(page)
            await asyncio.sleep(0.05)
        return [r for r in rows if _is_btc_5m_slug(r.get("slug", ""))]

    async def fetch_trades_since(
        self,
        safe: str,
        since_ts: int,
        *,
        page_size: int = 500,
        max_pages: int = 20,
    ) -> list[dict[str, Any]]:
        """Fetch *safe*'s trades with ``timestamp >= since_ts``, keeping BTC 5m markets only.

        Paging stops as soon as a page contains a row older than *since_ts*.
        NOTE(review): that early stop assumes the API returns trades newest
        first; confirm against the data-API docs. At most *max_pages* pages are
        read, so very active users may be truncated.

        Raises:
            aiohttp.ClientResponseError: on a non-400 HTTP error.
        """
        rows: list[dict[str, Any]] = []
        offset = 0
        for _ in range(max_pages):
            try:
                page = await self.fetch_trades_page(safe, limit=page_size, offset=offset)
            except ClientResponseError as exc:
                if exc.status == 400:
                    break
                raise
            if not page:
                break
            rows.extend(r for r in page if self._row_ts(r) >= since_ts)
            oldest = min(self._row_ts(r) for r in page)
            if oldest < since_ts or len(page) < page_size:
                break
            offset += len(page)
            await asyncio.sleep(0.05)
        return [r for r in rows if _is_btc_5m_slug(r.get("slug", ""))]

    async def fetch_activity_window(
        self,
        safe: str,
        *,
        start_ts: int,
        end_ts: int,
        limit: int = 500,
    ) -> list[dict[str, Any]]:
        """Fetch every ``/activity`` row for *safe* between *start_ts* and *end_ts*.

        Pages are requested in ascending order until the API returns an empty or
        short page or responds with 400. Rows are returned unfiltered.

        If a page begins with the same (transactionHash, timestamp, type) as an
        earlier page, the loop stops. This guards against an endpoint that
        ignores ``offset`` and would otherwise be paged forever.

        Raises:
            aiohttp.ClientResponseError: on a non-400 HTTP error.
        """
        rows: list[dict[str, Any]] = []
        offset = 0
        seen_first_keys: set[tuple[str, int, str]] = set()
        while True:
            try:
                data = await self._get_json(
                    f"{self.data_api_base}/activity",
                    {
                        "user": safe,
                        "start": start_ts,
                        "end": end_ts,
                        "limit": limit,
                        "offset": offset,
                        "sortDirection": "ASC",
                    },
                )
            except ClientResponseError as exc:
                if exc.status == 400:
                    break
                raise
            page = data if isinstance(data, list) else []
            if not page:
                break

            first = page[0]
            first_key = (
                str(first.get("transactionHash", "")),
                self._row_ts(first),
                str(first.get("type", "")),
            )
            if first_key in seen_first_keys:
                break
            seen_first_keys.add(first_key)

            rows.extend(page)
            if len(page) < limit:
                break
            offset += len(page)
            await asyncio.sleep(0.05)
        return rows

    async def fetch_activity_range(
        self,
        safe: str,
        *,
        start_ts: int,
        end_ts: int,
        chunk_hours: int = 24,
    ) -> list[dict[str, Any]]:
        """Fetch *safe*'s activity over ``[start_ts, end_ts]`` in *chunk_hours* windows.

        Only rows for BTC 5m markets are kept. Returns ``[]`` when
        ``start_ts >= end_ts``.

        NOTE(review): consecutive windows share a boundary second (one window's
        ``end`` is the next one's ``start``). If the API treats both bounds as
        inclusive, rows at exactly that second are returned twice. Confirm, and
        dedupe if so.

        Raises:
            ValueError: if *chunk_hours* is not positive (the window would never
                advance).
            aiohttp.ClientResponseError: propagated from :meth:`fetch_activity_window`.
        """
        delta = timedelta(hours=chunk_hours)
        if delta <= timedelta(0):
            raise ValueError(f"chunk_hours must be positive, got {chunk_hours!r}")
        rows: list[dict[str, Any]] = []
        cursor = datetime.fromtimestamp(start_ts, tz=timezone.utc)
        end_dt = datetime.fromtimestamp(end_ts, tz=timezone.utc)
        while cursor < end_dt:
            nxt = min(cursor + delta, end_dt)
            rows.extend(
                await self.fetch_activity_window(
                    safe,
                    start_ts=int(cursor.timestamp()),
                    end_ts=int(nxt.timestamp()),
                )
            )
            cursor = nxt
            await asyncio.sleep(0.05)
        return [r for r in rows if _is_btc_5m_slug(r.get("slug", ""))]

    async def fetch_market_by_slug(self, slug: str) -> MarketInfo:
        """Return Gamma metadata for *slug*.

        If the request fails or yields no object, the slug-derived skeleton is
        returned instead.
        """
        skeleton = market_skeleton(slug)
        try:
            data = await self._get_json(f"{self.gamma_base}/markets/slug/{slug}")
        except Exception:
            # Deliberate best-effort: any fetch/decode failure degrades to the skeleton.
            return skeleton
        if not isinstance(data, dict) or not data:
            return skeleton
        return parse_gamma_market(data, slug)

    async def fetch_markets_by_slugs(
        self,
        slugs: Iterable[str],
        *,
        concurrency: int = 8,
    ) -> dict[str, MarketInfo]:
        """Fetch markets for *slugs* concurrently, returning ``{slug: MarketInfo}``.

        At most *concurrency* requests are in flight at once. Duplicate slugs
        are fetched only once, and keys keep first-occurrence order.
        """
        sem = asyncio.Semaphore(concurrency)

        async def one(slug: str) -> tuple[str, MarketInfo]:
            async with sem:
                market = await self.fetch_market_by_slug(slug)
                # Short pause while still holding the slot, to throttle request rate.
                await asyncio.sleep(0.02)
                return slug, market

        unique_slugs = list(dict.fromkeys(slugs))
        pairs = await asyncio.gather(*(one(slug) for slug in unique_slugs))
        return dict(pairs)

    async def fetch_receipt(self, tx_hash: str) -> Optional[dict[str, Any]]:
        """Return the ``eth_getTransactionReceipt`` result for *tx_hash*, or None.

        None is returned when the RPC response has no object-valued ``result``
        (for example a JSON-RPC error, or a null result).

        Raises:
            aiohttp.ClientResponseError: when the HTTP status is not 2xx.
        """
        payload = {
            "jsonrpc": "2.0",
            "method": "eth_getTransactionReceipt",
            "params": [tx_hash],
            "id": 1,
        }
        async with self.http.post(
            self.rpc_url,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=20),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()
        result = data.get("result") if isinstance(data, dict) else None
        return result if isinstance(result, dict) else None


def market_skeleton(slug: str) -> MarketInfo:
    """Build a placeholder ``MarketInfo`` using only what the slug encodes.

    The session start comes from ``slug_start_ts``, and the end is one
    ``SESSION_SEC`` later. Both are left as None when no timestamp can be
    extracted. The result is tagged ``source="skeleton"``.
    """
    session_start = slug_start_ts(slug)
    if session_start is None:
        start_time = None
        end_time = None
    else:
        start_time = datetime.fromtimestamp(session_start, tz=timezone.utc)
        end_time = datetime.fromtimestamp(session_start + SESSION_SEC, tz=timezone.utc)
    return MarketInfo(
        slug=slug,
        event_slug=slug,
        start_time=start_time,
        end_time=end_time,
        source="skeleton",
    )


def generate_market_slugs(start_ts: int, end_ts: int) -> list[str]:
    """List the slug of every session from the one containing *start_ts*
    through the one containing *end_ts*, inclusive.

    The list is empty when *end_ts*'s session precedes *start_ts*'s.
    """
    first_session = floor_session_ts(start_ts)
    # range() excludes its stop value, so stop one session past the last one.
    stop = floor_session_ts(end_ts) + SESSION_SEC
    return [
        f"{BTC_PREFIX}{session_ts}"
        for session_ts in range(first_session, stop, SESSION_SEC)
    ]


def parse_gamma_market(raw: dict[str, Any], slug: str) -> MarketInfo:
    """Convert one Gamma market payload into a ``MarketInfo``.

    Start and end times fall back to the values ``market_skeleton(slug)``
    derives from the slug.

    A winner is decided in one of two ways:

    * If the market is closed or UMA-resolved, the winner is the outcome whose
      ``outcomePrices`` entry is >= 0.999.
    * Otherwise, if ``extract_ref_prices`` yields both reference prices, the
      winner is "Up" when ``final_price >= price_to_beat`` and "Down" otherwise.
      The market is then also marked resolved.

    Args:
        raw: Decoded JSON object returned by the Gamma API for one market.
        slug: Market slug; used for fallbacks and stored on the result.

    Returns:
        A ``MarketInfo`` tagged ``source="gamma"``.
    """
    skeleton = market_skeleton(slug)
    outcomes = _json_list(raw.get("outcomes")) or ["Up", "Down"]
    token_ids = _json_list(raw.get("clobTokenIds"))

    # clobTokenIds are positionally aligned with outcomes. Find each side by its
    # label, falling back to the conventional [Up, Down] order.
    up_idx = next((i for i, o in enumerate(outcomes) if str(o).lower() == "up"), 0)
    down_idx = next(
        (i for i, o in enumerate(outcomes) if str(o).lower() == "down"),
        1 if len(outcomes) > 1 else 0,
    )

    up_token = str(token_ids[up_idx]) if len(token_ids) > up_idx else ""
    down_token = str(token_ids[down_idx]) if len(token_ids) > down_idx else ""

    start_time = parse_dt(raw.get("eventStartTime")) or skeleton.start_time
    end_time = parse_dt(raw.get("endDate")) or skeleton.end_time

    event: dict[str, Any] = {}
    events = raw.get("events")
    if isinstance(events, list) and events:
        event = events[0] if isinstance(events[0], dict) else {}
        # The parent event's startTime, when parseable, takes precedence.
        start_time = parse_dt(event.get("startTime")) or start_time

    price_to_beat, final_price = extract_ref_prices(event, raw)
    outcome_prices = _json_list(raw.get("outcomePrices"))

    winning = ""
    resolved = bool(raw.get("closed")) or raw.get("umaResolutionStatus") == "resolved"
    if resolved and len(outcome_prices) >= len(outcomes):
        try:
            # zip() pairs each outcome with its price and drops surplus prices,
            # so an over-long outcomePrices array cannot index past outcomes.
            priced = [(float(p), str(o)) for o, p in zip(outcomes, outcome_prices)]
        except (TypeError, ValueError):
            priced = []
        if priced:
            # max() keeps the first maximal entry, matching first-index tie-breaking.
            best_price, best_outcome = max(priced, key=lambda pair: pair[0])
            if best_price >= 0.999:
                winning = best_outcome

    if not winning and final_price is not None and price_to_beat is not None:
        # A tie (final == reference) counts as "Up".
        winning = "Up" if final_price >= price_to_beat else "Down"
        resolved = True

    return MarketInfo(
        slug=slug,
        event_slug=raw.get("eventSlug") or slug,
        # `or ""` also covers keys present with a JSON null value.
        condition_id=raw.get("conditionId") or "",
        title=raw.get("question") or "",
        start_time=start_time,
        end_time=end_time,
        up_token_id=up_token,
        down_token_id=down_token,
        resolved=resolved,
        winning_outcome=winning,
        final_price=final_price,
        price_to_beat=price_to_beat,
        source="gamma",
    )
