from __future__ import annotations

import hashlib
import json
import math
import os
import ssl
from datetime import datetime, timezone
from typing import Any, Iterable
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

import asyncpg

from .types import SafeAction, dt_to_iso, normalize_addr, parse_dt


# Schema for `safe_btc5_actions`: one row per SafeAction, keyed by `action_id`
# (SafeBtc5DbStore.upsert_actions inserts with ON CONFLICT (action_id) DO UPDATE).
# Secondary indexes: (safe, action_ts DESC) and (market_slug, action_ts DESC).
# NOTE: SafeBtc5DbStore.init_schema executes this string statement-by-statement
# after splitting on ';', so no statement may contain a literal semicolon.
DDL_SAFE_ACTIONS = """
CREATE TABLE IF NOT EXISTS safe_btc5_actions (
    action_id TEXT PRIMARY KEY,
    safe TEXT NOT NULL,
    market_slug TEXT NOT NULL,
    condition_id TEXT,
    tx_hash TEXT,
    action_ts TIMESTAMPTZ NOT NULL,
    action_type TEXT NOT NULL,
    source TEXT,
    raw_type TEXT,
    outcome TEXT,
    side TEXT,
    price DOUBLE PRECISION,
    shares DOUBLE PRECISION,
    gross_usdc DOUBLE PRECISION,
    net_usdc_delta DOUBLE PRECISION,
    usdc_delta DOUBLE PRECISION,
    up_delta DOUBLE PRECISION,
    down_delta DOUBLE PRECISION,
    maker_taker_role TEXT,
    counterparties TEXT,
    pnl_eligible BOOLEAN,
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS safe_btc5_actions_safe_ts_idx
    ON safe_btc5_actions (safe, action_ts DESC);
CREATE INDEX IF NOT EXISTS safe_btc5_actions_slug_ts_idx
    ON safe_btc5_actions (market_slug, action_ts DESC);
"""

# Schema for `safe_btc5_strategy_inference`: one row per strategy-inference
# record, keyed by `inference_id` (a SHA-256 hex digest built by _inference_id),
# upserted by SafeBtc5DbStore.upsert_strategy_inference. `raw_json` keeps the
# complete source row as JSONB alongside the typed columns.
# The `btc_*_chainlink` columns are filled from the row keys `btc_trade_price`,
# `btc_offset_usd` and `btc_offset_pct` (no `_chainlink` suffix on the input).
# NOTE: init_schema splits this string on ';' before executing it.
DDL_SAFE_INFERENCE = """
CREATE TABLE IF NOT EXISTS safe_btc5_strategy_inference (
    inference_id TEXT PRIMARY KEY,
    safe TEXT NOT NULL,
    market_slug TEXT NOT NULL,
    condition_id TEXT,
    title TEXT,
    market_start_utc TIMESTAMPTZ,
    trade_timestamp_utc TIMESTAMPTZ NOT NULL,
    trade_timestamp_et TEXT,
    date_et TEXT,
    week_start_et TEXT,
    seconds_after_start DOUBLE PRECISION,
    seconds_before_end DOUBLE PRECISION,
    tx_hash TEXT,
    sold_outcome TEXT,
    winning_outcome TEXT,
    market_resolved BOOLEAN,
    sold_side_won BOOLEAN,
    sold_side_lost BOOLEAN,
    maker_taker_role TEXT,
    counterparties TEXT,
    price DOUBLE PRECISION,
    shares DOUBLE PRECISION,
    gross_usdc DOUBLE PRECISION,
    net_usdc_delta_aux DOUBLE PRECISION,
    profit_if_sold_side_loses_usdc DOUBLE PRECISION,
    loss_if_sold_side_wins_usdc DOUBLE PRECISION,
    pnl_if_sold_side_wins_usdc DOUBLE PRECISION,
    actual_direction_pnl_usdc DOUBLE PRECISION,
    break_even_sold_side_loss_probability DOUBLE PRECISION,
    btc_open_price DOUBLE PRECISION,
    btc_final_price DOUBLE PRECISION,
    btc_trade_price_chainlink DOUBLE PRECISION,
    btc_offset_usd_chainlink DOUBLE PRECISION,
    btc_offset_pct_chainlink DOUBLE PRECISION,
    snapshot_ts_utc TIMESTAMPTZ,
    snapshot_distance_sec DOUBLE PRECISION,
    snapshot_time_remaining_sec DOUBLE PRECISION,
    raw_json JSONB NOT NULL,
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS safe_btc5_inference_safe_ts_idx
    ON safe_btc5_strategy_inference (safe, trade_timestamp_utc DESC);
CREATE INDEX IF NOT EXISTS safe_btc5_inference_slug_ts_idx
    ON safe_btc5_strategy_inference (market_slug, trade_timestamp_utc DESC);
CREATE INDEX IF NOT EXISTS safe_btc5_inference_side_idx
    ON safe_btc5_strategy_inference (sold_outcome, sold_side_lost);
"""

# Schema for `btc_binance_ticks`: rows keyed by a surrogate BIGSERIAL `id`.
# This module only ever INSERTs into it (SafeBtc5DbStore.insert_binance_tick,
# no ON CONFLICT), so duplicate ticks are not de-duplicated here.
# Indexed by received_at DESC and event_ts DESC.
# NOTE: init_schema splits this string on ';' before executing it.
DDL_BINANCE_TICKS = """
CREATE TABLE IF NOT EXISTS btc_binance_ticks (
    id BIGSERIAL PRIMARY KEY,
    symbol TEXT NOT NULL,
    event_ts TIMESTAMPTZ,
    trade_ts TIMESTAMPTZ,
    received_at TIMESTAMPTZ NOT NULL,
    price DOUBLE PRECISION NOT NULL,
    quantity DOUBLE PRECISION,
    raw_json JSONB NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS btc_binance_ticks_received_idx
    ON btc_binance_ticks (received_at DESC);
CREATE INDEX IF NOT EXISTS btc_binance_ticks_event_idx
    ON btc_binance_ticks (event_ts DESC);
"""


def database_url(default: str | None = None) -> str:
    """Resolve the Postgres DSN.

    A non-empty ``DATABASE_URL`` environment variable always wins; *default*
    is only consulted when that variable is unset or empty.

    Raises:
        RuntimeError: neither source provides a non-empty DSN.
    """
    for candidate in (os.environ.get("DATABASE_URL"), default):
        if candidate:
            return candidate
    raise RuntimeError("DATABASE_URL is required")


def split_dsn_ssl(dsn: str) -> tuple[str, ssl.SSLContext | None]:
    """Turn a libpq-style ``sslmode`` DSN parameter into an SSL context.

    For the modes that need certificate handling, the ``sslmode`` parameter is
    removed from the query string and an :class:`ssl.SSLContext` is returned
    for ``asyncpg.connect(ssl=...)``:

    * ``require``     - encrypt, but do not verify the server certificate.
    * ``verify-ca``   - verify the certificate chain, but not the host name.
    * ``verify-full`` - verify the certificate chain and the host name.

    A blank ``sslmode=`` is dropped. Any other non-blank mode (e.g.
    ``disable``/``allow``/``prefer``) is left in the DSN so the connection
    library can honour it, rather than being silently discarded.

    Args:
        dsn: Postgres connection URL, optionally carrying ``?sslmode=...``.

    Returns:
        ``(clean_dsn, ssl_context_or_None)``. Other query parameters are
        preserved, including repeated ones.
    """
    parsed = urlparse(dsn)
    params = parse_qs(parsed.query, keep_blank_values=True)
    ssl_mode = params.get("sslmode", [None])[0]

    ssl_ctx: ssl.SSLContext | None = None
    if ssl_mode in ("require", "verify-ca", "verify-full"):
        ssl_ctx = ssl.create_default_context()
        if ssl_mode == "require":
            # check_hostname must be disabled before verify_mode can be CERT_NONE.
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
        elif ssl_mode == "verify-ca":
            # libpq semantics: validate the chain (CERT_REQUIRED stays on) but
            # skip the host-name match that distinguishes verify-full.
            ssl_ctx.check_hostname = False

    if ssl_ctx is not None or not ssl_mode:
        # Consumed above (or blank): strip it so it is not applied twice.
        params.pop("sslmode", None)

    # doseq=True keeps repeated parameters (e.g. several `options=`) intact.
    clean_query = urlencode(params, doseq=True)
    clean_dsn = urlunparse(parsed._replace(query=clean_query))
    return clean_dsn, ssl_ctx


def _float_or_none(value: Any) -> float | None:
    if value in ("", None):
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _bool_or_none(value: Any) -> bool | None:
    if value in ("", None):
        return None
    if isinstance(value, bool):
        return value
    text = str(value).lower()
    if text == "true":
        return True
    if text == "false":
        return False
    return None


def _json_dumps(value: Any) -> str:
    return json.dumps(value, ensure_ascii=False, sort_keys=True)


def _inference_id(safe: str, row: dict[str, Any]) -> str:
    """Return a deterministic SHA-256 hex id for one strategy-inference row.

    The key is the normalized Safe address followed by the row's market slug,
    lower-cased tx hash, trade timestamp, sold outcome, shares and gross USDC.
    Each field goes through ``str()`` (missing keys become ``""``), and the
    fields are joined with ``|``.

    NOTE: the exact string form of every field is part of the key (``5`` and
    ``5.0`` hash differently). The result is the table's primary key, so any
    change to this recipe re-keys all existing rows.
    """
    owner = normalize_addr(safe)
    market = str(row.get("market_slug", ""))
    tx = str(row.get("tx_hash", "")).lower()
    traded_at = str(row.get("trade_timestamp_utc", ""))
    outcome = str(row.get("sold_outcome", ""))
    share_count = str(row.get("shares", ""))
    gross = str(row.get("gross_usdc", ""))
    key = "|".join((owner, market, tx, traded_at, outcome, share_count, gross))
    return hashlib.sha256(key.encode("utf-8")).hexdigest()


class SafeBtc5DbStore:
    """Persist Safe BTC5 actions, strategy-inference rows and Binance ticks.

    Wraps a single asyncpg connection. Used as an async context manager it
    connects and creates the schema on entry and closes the connection on
    exit; ``connect()``, ``init_schema()`` and ``close()`` can also be called
    directly.

    The DSN is resolved with ``database_url(dsn)``, so a non-empty
    ``DATABASE_URL`` environment variable takes precedence over the *dsn*
    argument, which acts only as a fallback.
    """

    def __init__(self, dsn: str | None = None) -> None:
        # NOTE(review): DATABASE_URL wins over an explicit `dsn` here, which can
        # surprise callers (e.g. tests); kept as-is for compatibility.
        self.dsn = database_url(dsn)
        self.conn: asyncpg.Connection | None = None

    async def __aenter__(self) -> "SafeBtc5DbStore":
        await self.connect()
        try:
            await self.init_schema()
        except BaseException:
            # __aexit__ is not called when __aenter__ raises, so release the
            # connection opened above instead of leaking it.
            await self.close()
            raise
        return self

    async def __aexit__(self, *_exc: object) -> None:
        await self.close()

    def _require_conn(self) -> asyncpg.Connection:
        """Return the open connection.

        Raises:
            RuntimeError: ``connect()`` has not been called (or the store was
                closed). This is an explicit raise rather than ``assert`` so
                the check survives ``python -O``.
        """
        if self.conn is None:
            raise RuntimeError("SafeBtc5DbStore is not connected; call connect() first")
        return self.conn

    async def connect(self) -> None:
        """Open the asyncpg connection using the DSN/SSL pair from split_dsn_ssl."""
        clean_dsn, ssl_ctx = split_dsn_ssl(self.dsn)
        self.conn = await asyncpg.connect(clean_dsn, ssl=ssl_ctx)

    async def close(self) -> None:
        """Close the connection if one is open; safe to call repeatedly."""
        # Detach first so a failing close() cannot leave a dead connection
        # attached to the store.
        conn, self.conn = self.conn, None
        if conn is not None:
            await conn.close()

    async def init_schema(self) -> None:
        """Create all tables and indexes if they do not already exist."""
        conn = self._require_conn()
        for ddl in (DDL_SAFE_ACTIONS, DDL_SAFE_INFERENCE, DDL_BINANCE_TICKS):
            # A plain split on ';' is only valid because the DDL strings contain
            # no semicolons inside literals or bodies.
            for stmt in ddl.strip().split(";"):
                stmt = stmt.strip()
                if stmt:
                    await conn.execute(stmt)

    async def upsert_actions(self, actions: Iterable[SafeAction]) -> int:
        """Insert or update SafeAction rows keyed by ``action_id``.

        ``safe`` is normalized and ``tx_hash`` lower-cased before writing; on
        conflict every column is overwritten and ``updated_at`` is bumped.

        Returns:
            Number of rows sent to the database (0 for empty input).
        """
        conn = self._require_conn()
        rows = [
            (
                action.action_id,
                normalize_addr(action.safe),
                action.market_slug,
                action.condition_id,
                # tx_hash is nullable in the schema; tolerate a missing hash.
                action.tx_hash.lower() if action.tx_hash is not None else None,
                action.timestamp,
                action.action_type,
                action.source,
                action.raw_type,
                action.outcome,
                action.side,
                action.price,
                action.shares,
                action.gross_usdc,
                action.net_usdc_delta,
                action.usdc_delta,
                action.up_delta,
                action.down_delta,
                action.maker_taker_role,
                action.counterparties,
                action.pnl_eligible,
            )
            for action in actions
        ]
        if not rows:
            return 0
        await conn.executemany(
            """
            INSERT INTO safe_btc5_actions (
                action_id, safe, market_slug, condition_id, tx_hash, action_ts,
                action_type, source, raw_type, outcome, side, price, shares,
                gross_usdc, net_usdc_delta, usdc_delta, up_delta, down_delta,
                maker_taker_role, counterparties, pnl_eligible
            ) VALUES (
                $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21
            )
            ON CONFLICT (action_id) DO UPDATE SET
                safe = EXCLUDED.safe,
                market_slug = EXCLUDED.market_slug,
                condition_id = EXCLUDED.condition_id,
                tx_hash = EXCLUDED.tx_hash,
                action_ts = EXCLUDED.action_ts,
                action_type = EXCLUDED.action_type,
                source = EXCLUDED.source,
                raw_type = EXCLUDED.raw_type,
                outcome = EXCLUDED.outcome,
                side = EXCLUDED.side,
                price = EXCLUDED.price,
                shares = EXCLUDED.shares,
                gross_usdc = EXCLUDED.gross_usdc,
                net_usdc_delta = EXCLUDED.net_usdc_delta,
                usdc_delta = EXCLUDED.usdc_delta,
                up_delta = EXCLUDED.up_delta,
                down_delta = EXCLUDED.down_delta,
                maker_taker_role = EXCLUDED.maker_taker_role,
                counterparties = EXCLUDED.counterparties,
                pnl_eligible = EXCLUDED.pnl_eligible,
                updated_at = NOW()
            """,
            rows,
        )
        return len(rows)

    async def upsert_strategy_inference(
        self,
        rows: Iterable[dict[str, Any]],
        *,
        safe: str,
    ) -> int:
        """Insert or update strategy-inference rows for one Safe.

        Each row dict is coerced into typed columns (timestamps via
        ``parse_dt``, numbers via ``_float_or_none``, flags via
        ``_bool_or_none``), keyed by ``_inference_id``, and also stored whole
        in ``raw_json``.

        NOTE(review): ``market_slug`` and ``trade_timestamp_utc`` are NOT NULL
        in the schema; a row lacking either (or for which ``parse_dt`` returns
        None) will be rejected by Postgres.

        Returns:
            Number of rows sent to the database (0 for empty input).
        """
        conn = self._require_conn()
        safe = normalize_addr(safe)
        values = []
        for row in rows:
            raw_tx_hash = row.get("tx_hash")
            values.append(
                (
                    _inference_id(safe, row),
                    safe,
                    row.get("market_slug"),
                    row.get("condition_id"),
                    row.get("title"),
                    parse_dt(row.get("market_start_utc")),
                    parse_dt(row.get("trade_timestamp_utc")),
                    row.get("trade_timestamp_et"),
                    row.get("date_et"),
                    row.get("week_start_et"),
                    _float_or_none(row.get("seconds_after_start")),
                    _float_or_none(row.get("seconds_before_end")),
                    # Missing/None hash -> NULL (str(None) would store "none").
                    None if raw_tx_hash is None else str(raw_tx_hash).lower(),
                    row.get("sold_outcome"),
                    row.get("winning_outcome"),
                    _bool_or_none(row.get("market_resolved")),
                    _bool_or_none(row.get("sold_side_won")),
                    _bool_or_none(row.get("sold_side_lost")),
                    row.get("maker_taker_role"),
                    row.get("counterparties"),
                    _float_or_none(row.get("price")),
                    _float_or_none(row.get("shares")),
                    _float_or_none(row.get("gross_usdc")),
                    _float_or_none(row.get("net_usdc_delta_aux")),
                    _float_or_none(row.get("profit_if_sold_side_loses_usdc")),
                    _float_or_none(row.get("loss_if_sold_side_wins_usdc")),
                    _float_or_none(row.get("pnl_if_sold_side_wins_usdc")),
                    _float_or_none(row.get("actual_direction_pnl_usdc")),
                    _float_or_none(row.get("break_even_sold_side_loss_probability")),
                    _float_or_none(row.get("btc_open_price")),
                    _float_or_none(row.get("btc_final_price")),
                    # Input keys without the `_chainlink` suffix feed the
                    # btc_*_chainlink columns.
                    _float_or_none(row.get("btc_trade_price")),
                    _float_or_none(row.get("btc_offset_usd")),
                    _float_or_none(row.get("btc_offset_pct")),
                    parse_dt(row.get("snapshot_ts_utc")),
                    _float_or_none(row.get("snapshot_distance_sec")),
                    _float_or_none(row.get("snapshot_time_remaining_sec")),
                    _json_dumps(row),
                )
            )
        if not values:
            return 0
        await conn.executemany(
            """
            INSERT INTO safe_btc5_strategy_inference (
                inference_id, safe, market_slug, condition_id, title,
                market_start_utc, trade_timestamp_utc, trade_timestamp_et,
                date_et, week_start_et, seconds_after_start, seconds_before_end,
                tx_hash, sold_outcome, winning_outcome, market_resolved,
                sold_side_won, sold_side_lost, maker_taker_role, counterparties,
                price, shares, gross_usdc, net_usdc_delta_aux,
                profit_if_sold_side_loses_usdc, loss_if_sold_side_wins_usdc,
                pnl_if_sold_side_wins_usdc, actual_direction_pnl_usdc,
                break_even_sold_side_loss_probability, btc_open_price,
                btc_final_price, btc_trade_price_chainlink,
                btc_offset_usd_chainlink, btc_offset_pct_chainlink,
                snapshot_ts_utc, snapshot_distance_sec,
                snapshot_time_remaining_sec, raw_json
            ) VALUES (
                $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,
                $17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,
                $31,$32,$33,$34,$35,$36,$37,$38::jsonb
            )
            ON CONFLICT (inference_id) DO UPDATE SET
                safe = EXCLUDED.safe,
                market_slug = EXCLUDED.market_slug,
                condition_id = EXCLUDED.condition_id,
                title = EXCLUDED.title,
                market_start_utc = EXCLUDED.market_start_utc,
                trade_timestamp_utc = EXCLUDED.trade_timestamp_utc,
                trade_timestamp_et = EXCLUDED.trade_timestamp_et,
                date_et = EXCLUDED.date_et,
                week_start_et = EXCLUDED.week_start_et,
                seconds_after_start = EXCLUDED.seconds_after_start,
                seconds_before_end = EXCLUDED.seconds_before_end,
                tx_hash = EXCLUDED.tx_hash,
                sold_outcome = EXCLUDED.sold_outcome,
                winning_outcome = EXCLUDED.winning_outcome,
                market_resolved = EXCLUDED.market_resolved,
                sold_side_won = EXCLUDED.sold_side_won,
                sold_side_lost = EXCLUDED.sold_side_lost,
                maker_taker_role = EXCLUDED.maker_taker_role,
                counterparties = EXCLUDED.counterparties,
                price = EXCLUDED.price,
                shares = EXCLUDED.shares,
                gross_usdc = EXCLUDED.gross_usdc,
                net_usdc_delta_aux = EXCLUDED.net_usdc_delta_aux,
                profit_if_sold_side_loses_usdc = EXCLUDED.profit_if_sold_side_loses_usdc,
                loss_if_sold_side_wins_usdc = EXCLUDED.loss_if_sold_side_wins_usdc,
                pnl_if_sold_side_wins_usdc = EXCLUDED.pnl_if_sold_side_wins_usdc,
                actual_direction_pnl_usdc = EXCLUDED.actual_direction_pnl_usdc,
                break_even_sold_side_loss_probability = EXCLUDED.break_even_sold_side_loss_probability,
                btc_open_price = EXCLUDED.btc_open_price,
                btc_final_price = EXCLUDED.btc_final_price,
                btc_trade_price_chainlink = EXCLUDED.btc_trade_price_chainlink,
                btc_offset_usd_chainlink = EXCLUDED.btc_offset_usd_chainlink,
                btc_offset_pct_chainlink = EXCLUDED.btc_offset_pct_chainlink,
                snapshot_ts_utc = EXCLUDED.snapshot_ts_utc,
                snapshot_distance_sec = EXCLUDED.snapshot_distance_sec,
                snapshot_time_remaining_sec = EXCLUDED.snapshot_time_remaining_sec,
                raw_json = EXCLUDED.raw_json,
                updated_at = NOW()
            """,
            values,
        )
        return len(values)

    async def insert_binance_tick(
        self,
        *,
        symbol: str,
        event_ts: datetime | None,
        trade_ts: datetime | None,
        received_at: datetime,
        price: float,
        quantity: float | None,
        raw_json: dict[str, Any],
    ) -> None:
        """Append one Binance tick row; *raw_json* is stored as JSONB."""
        conn = self._require_conn()
        await conn.execute(
            """
            INSERT INTO btc_binance_ticks (
                symbol, event_ts, trade_ts, received_at, price, quantity, raw_json
            ) VALUES ($1,$2,$3,$4,$5,$6,$7::jsonb)
            """,
            symbol,
            event_ts,
            trade_ts,
            received_at,
            price,
            quantity,
            _json_dumps(raw_json),
        )


def action_row_for_debug(action: SafeAction) -> dict[str, Any]:
    """Build a row dict for debug output from ``action.to_row()``.

    The ``timestamp`` entry is overwritten with ``dt_to_iso(action.timestamp)``.
    The dict returned by ``to_row()`` is modified in place and returned as-is.
    """
    debug_row = action.to_row()
    iso_timestamp = dt_to_iso(action.timestamp)
    debug_row["timestamp"] = iso_timestamp
    return debug_row
