"""
Invoice PDF Generator
Generates invoice PDFs with WeasyPrint from a Jinja2 HTML template, preferring an
active database-stored template and falling back to the on-disk one.
"""
import ipaddress
import logging
import os
import re
import socket
from datetime import datetime, timezone
from io import BytesIO
from typing import Optional
from urllib.parse import unquote, urlparse

import jinja2
import jinja2.sandbox
from jinja2 import FileSystemLoader, select_autoescape
from sqlalchemy.orm import Session
from weasyprint import HTML

logger = logging.getLogger(__name__)

# Destinations a rendered PDF template must never be allowed to reach when
# resource URLs are validated (see _safe_url_fetcher).
_PRIVATE_NETWORKS = [
    ipaddress.ip_network("10.0.0.0/8"),        # RFC 1918 private
    ipaddress.ip_network("172.16.0.0/12"),     # RFC 1918 private
    ipaddress.ip_network("192.168.0.0/16"),    # RFC 1918 private
    ipaddress.ip_network("127.0.0.0/8"),       # IPv4 loopback
    ipaddress.ip_network("169.254.0.0/16"),    # IPv4 link-local (cloud metadata endpoints)
    ipaddress.ip_network("::1/128"),           # IPv6 loopback
    ipaddress.ip_network("fc00::/7"),          # IPv6 unique-local
    ipaddress.ip_network("0.0.0.0/8"),         # "this network"; 0.0.0.0 connects to localhost on Linux
    ipaddress.ip_network("100.64.0.0/10"),     # carrier-grade NAT / shared address space
    ipaddress.ip_network("fe80::/10"),         # IPv6 link-local
    ipaddress.ip_network("::/128"),            # IPv6 unspecified
    ipaddress.ip_network("::ffff:0:0/96"),     # IPv4-mapped IPv6 (can wrap any IPv4 above)
]


def _is_internal_address(ip) -> bool:
    """Return True if *ip* is not a publicly routable unicast address.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are unwrapped first so that
    e.g. ``::ffff:127.0.0.1`` is judged as the IPv4 address it points at.
    """
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped
    if (
        not ip.is_global
        or ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_reserved
        or ip.is_multicast
        or ip.is_unspecified
    ):
        return True
    # Explicit deny-list as a backstop for ranges a given Python version may
    # still classify as global.
    return any(ip in network for network in _PRIVATE_NETWORKS)


def _check_fetch_url(url: str) -> None:
    """Raise ValueError unless *url* is safe for WeasyPrint to fetch.

    Allowed: ``data:`` URIs (decoded in-process, no network access) and
    ``https`` URLs whose host resolves only to public addresses. Everything
    else (other schemes, missing hosts, unresolvable hosts, any internal
    address among the resolved ones) is rejected.
    """
    parsed = urlparse(url)
    if parsed.scheme == "data":
        # Inline payload, e.g. a logo embedded by InvoicePDFGenerator._resolve_logo.
        return
    if parsed.scheme != "https":
        # NOTE(review): relative asset references are resolved against
        # base_url=TEMPLATES_DIR and therefore arrive as file:// URLs, which
        # are rejected here — confirm the disk template does not rely on them.
        raise ValueError(f"Blocked non-HTTPS URL in PDF template: {parsed.scheme}://")
    host = parsed.hostname
    if not host:
        raise ValueError("Blocked PDF template URL without a host")
    try:
        infos = socket.getaddrinfo(host, None, proto=socket.IPPROTO_TCP)
    except socket.gaierror as exc:
        # Fail closed: a host we cannot vet must not be fetched.
        raise ValueError(f"Blocked unresolvable host in PDF template URL: {host}") from exc
    # Check every A/AAAA record, not just the first IPv4 one.
    for *_, sockaddr in infos:
        # Strip an IPv6 zone id ("fe80::1%eth0") before parsing.
        ip = ipaddress.ip_address(str(sockaddr[0]).split("%", 1)[0])
        if _is_internal_address(ip):
            raise ValueError(f"Blocked internal IP in PDF template URL: {ip}")


def _safe_url_fetcher(url: str):
    """Custom WeasyPrint URL fetcher that blocks SSRF vectors.

    In a development environment (``settings.APP_ENV`` in dev/development/local)
    every URL is passed straight to WeasyPrint's default fetcher so local/HTTP
    logo URLs render. Otherwise the URL must pass ``_check_fetch_url`` first.

    Raises:
        ValueError: if the URL is rejected; WeasyPrint treats a fetcher error
            as a failed resource load.

    NOTE(review): the default fetcher resolves the host again and follows
    HTTP redirects, so a public https URL that redirects (or re-resolves) to
    an internal address is not caught here.
    """
    try:
        from src.core.config import settings as _settings
        app_env = getattr(_settings, "APP_ENV", "dev") or "dev"
        _is_dev = app_env.lower() in ("dev", "development", "local")
    except Exception:
        # Fail closed: if the environment cannot be determined, keep the
        # restrictions on rather than silently disabling them.
        _is_dev = False

    if not _is_dev:
        _check_fetch_url(url)

    from weasyprint.urls import default_url_fetcher as _default_fetcher
    return _default_fetcher(url)

# Root of the on-disk Jinja2 templates: a "templates" directory three levels
# above this module's directory, with the ".." segments collapsed.
TEMPLATES_DIR = os.path.normpath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 3), "templates")
)


class InvoicePDFGenerator:
    """Generates a branded invoice PDF for a given Invoice ORM object.

    Rendering is DB-first: when ``generate`` receives a session and an active
    ``invoice_pdf`` site template with a body exists, that template is rendered
    in a sandboxed Jinja2 environment. Otherwise, or if that rendering fails,
    the on-disk ``invoice_pdf.html`` under ``TEMPLATES_DIR`` is used. WeasyPrint
    converts the HTML to PDF and fetches external resources through
    ``_safe_url_fetcher``.

    Monetary values on the ORM objects are treated as integer cents. They are
    converted to dollars (``_c``) before they reach a template.
    """

    # Fallback primary colour when the merchant has none (or an unusable one).
    _DEFAULT_BRAND_COLOR = "#1a56db"

    # Accepted shapes for a merchant-supplied CSS colour: a hex code, a bare
    # keyword, or an rgb()/rgba()/hsl()/hsla() function. Quotes, semicolons,
    # braces and nested parentheses are excluded, so the value cannot break
    # out of a CSS declaration if a template interpolates it into CSS.
    _CSS_COLOR_RE = re.compile(
        r"#[0-9A-Fa-f]{3,8}"
        r"|[A-Za-z]+"
        r"|(?:rgba?|hsla?)\([0-9A-Za-z.,%/\s+-]*\)"
    )

    def __init__(self) -> None:
        # Disk templates: sandboxed, with HTML autoescaping for *.html files.
        self.env = jinja2.sandbox.SandboxedEnvironment(
            loader=FileSystemLoader(TEMPLATES_DIR),
            autoescape=select_autoescape(["html"]),
        )

    def generate(self, invoice, db: Optional[Session] = None) -> bytes:
        """Render *invoice* to PDF and return the raw PDF bytes.

        Args:
            invoice: A fully-loaded Invoice ORM object (relations eager-loaded).
            db: Optional SQLAlchemy session. When given, an active
                ``invoice_pdf`` site template stored in the database takes
                precedence over the on-disk template.

        Returns:
            The PDF document as bytes.
        """
        html_content: Optional[str] = None
        if db is not None:
            html_content = self._render_db_template(invoice, db)

        if html_content is None:
            disk_template = self.env.get_template("invoice_pdf.html")
            html_content = disk_template.render(**self._build_context(invoice))

        return HTML(
            string=html_content,
            base_url=TEMPLATES_DIR,
            url_fetcher=_safe_url_fetcher,
        ).write_pdf()

    def _render_db_template(self, invoice, db: Session) -> Optional[str]:
        """Render the DB-stored ``invoice_pdf`` template.

        Returns None (meaning "use the disk template") when the template is
        missing, inactive or empty, or when loading or rendering raises. The
        failure is logged as a warning.
        """
        try:
            from src.apps.site_templates import crud as site_template_crud

            tmpl = site_template_crud.get_template_by_key(db, "invoice_pdf")
            if not (tmpl and tmpl.is_active and tmpl.body_html):
                return None
            # Autoescape: the context carries customer-supplied text (names,
            # emails, notes). Unescaped, it could inject markup and resource
            # URLs into the rendered PDF.
            sb_env = jinja2.sandbox.SandboxedEnvironment(
                loader=jinja2.BaseLoader(),
                autoescape=True,
            )
            db_ctx = self._build_db_context(invoice)
            return sb_env.from_string(tmpl.body_html).render(**db_ctx)
        except Exception as exc:
            logger.warning(
                "DB-first template rendering failed for invoice_pdf, falling back to disk: %s",
                exc,
            )
            return None

    def generate_pdf(self, invoice, merchant=None, customer=None, line_items=None, db: Optional[Session] = None) -> bytes:
        """Alias for ``generate()`` kept for callers that pass related objects separately.

        ``merchant``, ``customer`` and ``line_items`` are accepted for
        compatibility but ignored; everything is read from *invoice*.
        """
        return self.generate(invoice, db=db)

    def _build_db_context(self, invoice) -> dict:
        """Build the flatter context exposed to DB-stored templates.

        Derived from ``_build_context``. ``discount`` and ``tax`` are the
        first adjustment amounts of those types (0.0 if absent). ``notes`` is
        always a string.
        """
        ctx = self._build_context(invoice)
        adjustments = ctx["adjustments"]

        def _adjustment_amount(kind: str) -> float:
            return next((a["amount"] for a in adjustments if a.get("type") == kind), 0.0)

        return {
            "primary_color": ctx["brand_color"],
            "accent_color": "#FB7585",
            "text_color": "#252525",
            "merchant_name": ctx["merchant_name"],
            "customer_name": ctx["payer_name"],
            "customer_email": ctx["payer_email"],
            "customer_address": ctx["payer_address"],
            "invoice_number": ctx["invoice_literal"],
            "issue_date": ctx["issue_date"],
            "due_date": ctx["due_date"],
            "status": ctx["status"],
            "line_items": ctx["line_items"],
            "subtotal": ctx["subtotal"],
            "discount": _adjustment_amount("discount"),
            "tax": _adjustment_amount("tax"),
            "amount": ctx["total"],
            # A missing note must not render as the literal text "None".
            "notes": ctx.get("notes") or "",
        }

    # ── Context builder ──────────────────────────────────────────────────────

    def _build_context(self, invoice) -> dict:
        """Build the full template context for the on-disk invoice template.

        Amounts are dollars (converted from cents). Dates are preformatted
        strings, with "—" when missing.
        """
        return {
            "invoice_literal": invoice.invoice_literal or "",
            "issue_date": self._fmt_date(getattr(invoice, "billing_date", None) or invoice.created_at),
            "due_date": self._fmt_date(invoice.due_date),
            "status": getattr(invoice, "status_text", None) or str(invoice.status or ""),
            "now": datetime.now(timezone.utc).strftime("%B %d, %Y"),

            # Merchant
            "merchant_name": self._safe_get(invoice, "merchant.name") or "",
            **self._get_business_urls(invoice),
            "merchant_address": self._get_merchant_address(invoice),
            "merchant_logo_url": self._get_logo_url(invoice),
            "brand_color": self._get_brand_color(invoice),

            # Payer (CustomerContact) → fallback to Customer
            "payer_name": self._get_payer_name(invoice),
            "payer_email": self._get_payer_email(invoice),
            "payer_phone": self._get_payer_phone(invoice),
            "payer_address": self._get_customer_address(invoice),
            "account_number": self._safe_get(invoice, "customer.uin") or "N/A",

            # Invoice metadata
            "reference": getattr(invoice, "reference", None) or "N/A",
            "payment_type": self._fmt_enum(self._safe_get(invoice, "payment_request.payment_frequency")),
            "auth_type": self._fmt_enum(self._safe_get(invoice, "payment_request.authorization_type")),

            # Line items
            "line_items": self._get_line_items(invoice),

            # Cost breakdown (DB stores amounts in cents → convert to dollars)
            "subtotal": self._calc_subtotal(invoice),
            "adjustments": self._get_adjustments(invoice),
            "total": self._c(invoice.amount),
            "currency": "$",
            "amount_paid": self._c(invoice.paid_amount),
            "amount_due": self._calc_amount_due(invoice),

            # Payment URL
            "payment_url": self._safe_get(invoice, "payment_request.payment_url"),
            "terms": getattr(invoice, "terms", None),
            "notes": getattr(invoice, "notes", None),
        }

    # ── Field helpers ────────────────────────────────────────────────────────

    @staticmethod
    def _c(cents) -> float:
        """Convert a cents amount (as stored in DB) to dollars rounded to 2 places; None → 0.0."""
        if cents is None:
            return 0.0
        return round(cents / 100, 2)

    def _fmt_date(self, value) -> str:
        """Format a date/datetime as e.g. "January 05, 2024".

        None becomes "—", strings pass through unchanged, and anything without
        ``strftime`` falls back to ``str(value)``.
        """
        if value is None:
            return "—"
        if isinstance(value, str):
            return value
        try:
            return value.strftime("%B %d, %Y")
        except Exception:
            return str(value)

    def _safe_get(self, obj, path: str):
        """Safely traverse a dotted attribute path.

        Returns None if any step is None or a missing attribute.
        """
        current = obj
        for part in path.split("."):
            if current is None:
                return None
            current = getattr(current, part, None)
        return current

    @staticmethod
    def _format_address(addr) -> str:
        """Join an address object's street, city, state and zip with ", ", skipping blanks."""
        if not addr:
            return ""
        parts = (
            getattr(addr, "address_line_1", "") or getattr(addr, "address_line1", ""),
            getattr(addr, "city", ""),
            getattr(addr, "state", ""),
            getattr(addr, "zip_code", "") or getattr(addr, "zipcode", ""),
        )
        return ", ".join(p for p in parts if p)

    def _get_merchant_address(self, invoice) -> str:
        """Return the merchant's formatted default address, or ""."""
        merchant = getattr(invoice, "merchant", None)
        if not merchant:
            return ""
        return self._format_address(getattr(merchant, "default_address", None))

    @staticmethod
    def _full_name(person) -> str:
        """Return "first last" for *person* with missing parts dropped ("" if both missing)."""
        first = getattr(person, "first_name", "") or ""
        last = getattr(person, "last_name", "") or ""
        return f"{first} {last}".strip()

    def _get_payer_name(self, invoice) -> str:
        """Return the payer's name.

        Falls back to the customer's name, then the customer's business
        legal name, then "—".
        """
        payer = getattr(invoice, "payer", None)
        if payer:
            name = self._full_name(payer)
            if name:
                return name
        customer = getattr(invoice, "customer", None)
        if customer:
            name = self._full_name(customer) or (getattr(customer, "business_legal_name", "") or "")
            if name:
                return name
        return "—"

    def _get_payer_email(self, invoice) -> str:
        """Return the payer's email, else the customer's email, else ""."""
        payer = getattr(invoice, "payer", None)
        if payer and getattr(payer, "email", None):
            return payer.email
        customer = getattr(invoice, "customer", None)
        if customer and getattr(customer, "email", None):
            return customer.email
        return ""

    def _get_payer_phone(self, invoice) -> str:
        """Return the customer's phone number, or ""."""
        customer = getattr(invoice, "customer", None)
        if customer and getattr(customer, "phone", None):
            return customer.phone
        return ""

    def _get_customer_address(self, invoice) -> str:
        """Return the customer's billing address (falling back to the default address), formatted."""
        customer = getattr(invoice, "customer", None)
        if not customer:
            return ""
        addr = getattr(customer, "billing_address", None) or getattr(customer, "default_address", None)
        return self._format_address(addr)

    def _get_line_items(self, invoice) -> list:
        """Return template-ready line-item dicts with dollar prices and totals.

        When the invoice has no line items, a single "Quick Charge" row for
        the full invoice amount is returned instead.
        """
        items = getattr(invoice, "invoice_line_items", None) or []
        result = []
        for item in items:
            qty = getattr(item, "quantity", None)
            if qty is None:
                # Only a missing quantity defaults to 1; an explicit 0 stays 0.
                qty = 1
            unit_price_cents = getattr(item, "unit_price", 0) or 0
            result.append({
                "description": getattr(item, "title", None) or getattr(item, "description", "") or "",
                "quantity": qty,
                "unit_price": self._c(unit_price_cents),
                "tax_rate": getattr(item, "tax", None),
                # Multiply in cents, then convert once. This avoids
                # Decimal*float TypeErrors and compounding float error.
                "total": self._c(qty * unit_price_cents),
            })
        if not result:
            # No line items — treat as a quick charge
            amount = self._c(invoice.amount)
            result.append({
                "description": "Quick Charge",
                "quantity": 1,
                "unit_price": amount,
                "tax_rate": None,
                "total": amount,
            })
        return result

    def _calc_subtotal(self, invoice) -> float:
        """Sum the displayed line-item totals.

        With no items this is the quick-charge row, i.e. the invoice amount.
        The subtotal therefore always agrees with the rows shown on the PDF.
        """
        return round(sum(line["total"] for line in self._get_line_items(invoice)), 2)

    def _get_adjustments(self, invoice) -> list:
        """Return the cost-adjustment rows (tax, shipping, tip, discount, optional surcharge) in dollars."""
        result = []
        # Sales Tax — always shown (matches UI behaviour)
        tax_fee = getattr(invoice, "tax_fee", None) or 0
        result.append({"label": "Sales Tax", "amount": self._c(tax_fee), "type": "tax", "percentage": "0.00 %"})
        # Shipping — always shown
        shipping_fee = getattr(invoice, "shipping_fee", None) or 0
        result.append({"label": "Shipping", "amount": self._c(shipping_fee), "type": "shipping", "percentage": ""})
        # Tip — always shown
        tip = getattr(invoice, "tip", None) or 0
        result.append({"label": "TIP", "amount": self._c(tip), "type": "tip", "percentage": "0.00 %"})
        # Discount — always shown; prefer adjustment relation, then invoice-level field
        adj = getattr(invoice, "adjustment", None)
        if adj and getattr(adj, "is_discounted", False):
            disc_amount = getattr(adj, "discount_amount", 0) or 0
            disc_name = getattr(adj, "discount_name", None) or "Discount"
        else:
            disc_amount = getattr(invoice, "discount", None) or 0
            disc_name = "Discount"
        result.append({"label": disc_name, "amount": self._c(disc_amount), "type": "discount", "percentage": ""})
        # Surcharge — only shown when present
        surcharge = getattr(invoice, "surcharge", None)
        if surcharge:
            result.append({"label": "Surcharge", "amount": self._c(surcharge), "type": "surcharge", "percentage": ""})
        return result

    def _get_business_urls(self, invoice) -> dict:
        """Read support email, phone and website from branding settings, respecting show toggles.

        Email and phone fall back to the merchant model's own fields when the
        settings are empty, hidden or unavailable. Website has no fallback.
        """
        merchant = getattr(invoice, "merchant", None)
        email = phone = website = ""
        if merchant:
            try:
                from src.apps.settings.helper import get_setting_with_default
                mid = merchant.id

                def _raw(key):
                    return get_setting_with_default("Business URLs", key, mid)

                def _get(key):
                    return _raw(key) or ""

                def _show(key) -> bool:
                    # Toggles may be stored as real booleans or as strings;
                    # an unset toggle means "show".
                    flag = _raw(key)
                    if isinstance(flag, bool):
                        return flag
                    return (str(flag) if flag else "true").lower() != "false"

                email = _get("support_email") if _show("show_support_email") else ""
                phone = _get("support_phone") if _show("show_support_phone") else ""
                website = _get("website_url") if _show("show_website_url") else ""
            except Exception:
                logger.debug("Business URL settings unavailable for invoice PDF", exc_info=True)
                email = phone = website = ""

            # Fall back to merchant model fields if branding settings are empty
            email = email or getattr(merchant, "email", "") or ""
            phone = phone or getattr(merchant, "phone", "") or ""
        return {"merchant_email": email, "merchant_phone": phone, "merchant_website": website}

    def _get_logo_url(self, invoice) -> str:
        """Read the merchant logo URL from merchant settings and resolve it via ``_resolve_logo``.

        Returns "" when there is no merchant or logo, or when the lookup fails.
        """
        try:
            from src.apps.settings.helper import get_setting_with_default
            merchant = getattr(invoice, "merchant", None)
            if merchant:
                logo = get_setting_with_default("Logo", "logo_url", merchant.id)
                if logo:
                    return self._resolve_logo(logo)
        except Exception:
            logger.debug("Could not load merchant logo for invoice PDF", exc_info=True)
        return ""

    @staticmethod
    def _resolve_logo(url: str) -> str:
        """Convert a locally served logo URL into a base64 data URI.

        If *url* points under ``SERVER_HOST + STATIC_FILES_PATH`` and maps to
        an existing file inside ``UPLOADS_DIR``, the file is read from disk and
        inlined. In every other case *url* is returned unchanged.
        """
        import base64
        import mimetypes
        try:
            from src.core.config import settings as _s
            server_host = str(_s.SERVER_HOST).rstrip("/")
            static_path = str(_s.STATIC_FILES_PATH).rstrip("/")
            uploads_dir = os.path.realpath(str(_s.UPLOADS_DIR))
            prefix = server_host + static_path
            if url.startswith(prefix):
                rel = url[len(prefix):]
                # Drop cache-busting query strings / fragments, then decode
                # %-escapes so file names with spaces etc. are found.
                rel = unquote(rel.split("#", 1)[0].split("?", 1)[0])
                file_path = os.path.realpath(os.path.join(uploads_dir, rel.lstrip("/")))
                # Refuse paths that escape the uploads directory
                # (e.g. ".../static/../../etc/passwd").
                inside_uploads = os.path.commonpath([uploads_dir, file_path]) == uploads_dir
                if inside_uploads and os.path.isfile(file_path):
                    mime, _ = mimetypes.guess_type(file_path)
                    mime = mime or "image/png"
                    with open(file_path, "rb") as f:
                        data = base64.b64encode(f.read()).decode()
                    return f"data:{mime};base64,{data}"
        except Exception:
            logger.debug("Could not inline logo %r for invoice PDF", url, exc_info=True)
        return url

    def _get_brand_color(self, invoice) -> str:
        """Read the merchant's primary colour from merchant settings.

        Values that do not look like a plain CSS colour (see ``_CSS_COLOR_RE``)
        are ignored. Falls back to the default colour.
        """
        try:
            from src.apps.settings.helper import get_setting_with_default
            merchant = getattr(invoice, "merchant", None)
            if merchant:
                color = get_setting_with_default("Colors", "primary_color", merchant.id)
                if color:
                    color = str(color).strip()
                    if self._CSS_COLOR_RE.fullmatch(color):
                        return color
                    logger.warning(
                        "Ignoring invalid primary_color %r for merchant %s", color, merchant.id
                    )
        except Exception:
            logger.debug("Could not load brand colour for invoice PDF", exc_info=True)
        return self._DEFAULT_BRAND_COLOR

    def _calc_amount_due(self, invoice) -> float:
        """Return invoice amount minus paid amount in dollars, never below 0."""
        total = self._c(invoice.amount)
        paid = self._c(invoice.paid_amount)
        return round(max(total - paid, 0), 2)

    @staticmethod
    def _fmt_enum(value) -> str:
        """Convert a snake_case enum string to Title Case (e.g. 'request_auth' → 'Request Auth').

        Enum members are unwrapped to their ``.value`` first, so they do not
        render as "ClassName.MEMBER". Empty values give "—".
        """
        value = getattr(value, "value", value)
        if not value:
            return "—"
        return str(value).replace("_", " ").title()

# Module-level singleton, constructed once at import time (which builds the
# disk-template Jinja2 environment in InvoicePDFGenerator.__init__).
invoice_pdf_generator = InvoicePDFGenerator()
