"""
Role Permissions CRUD — PRD-008 RBAC Multi-User Merchant Accounts

All queries use SQLAlchemy 2.0-style select() / update() / delete() constructs.
"""

from __future__ import annotations

import logging
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import select, update, delete, and_
from sqlalchemy.orm import Session

from src.apps.role_permissions.models.role import Role
from src.apps.role_permissions.models.permission import Permission
from src.apps.role_permissions.models.user_role import UserRole
from src.apps.role_permissions.models.merchant_invite import MerchantInvite
from src.apps.role_permissions.models.merchant_audit_log import MerchantAuditLog

logger = logging.getLogger(__name__)

# ─── System role definitions ─────────────────────────────────────────────────
# Global roles (merchant_id=None) inserted by seed_system_roles(), which skips
# any slug that already exists. role_rank orders privilege: a higher number is
# more privileged, and the listing helpers in this module hide roles/members
# whose rank is above the caller's rank.
# NOTE(review): the Owner description promises "cannot be deleted or edited",
# but update_role in this module does not check is_system — confirm that the
# callers enforce that rule.

_SYSTEM_ROLES = [
    {
        "label": "Owner",
        "slug": "owner",
        "description": "Full access to all features. Cannot be deleted or edited.",
        "is_system": True,
        "is_default": False,
        "role_rank": 3,  # highest rank
    },
    {
        "label": "Admin",
        "slug": "admin",
        "description": "Manage team members, roles, and most settings.",
        "is_system": True,
        "is_default": False,
        "role_rank": 2,
    },
    {
        "label": "Staff",
        "slug": "staff",
        "description": "Basic access — process payments, view reports.",
        "is_system": True,
        # Only Staff is flagged default; delete_role also falls back to the
        # "staff" slug when reassigning members of a deleted role.
        "is_default": True,
        "role_rank": 1,  # lowest rank
    },
]

# ─── Permission CRUD ─────────────────────────────────────────────────────────


def get_all_permissions(db: Session) -> list[Permission]:
    """Fetch every Permission row, sorted ascending by its display_order column."""
    ordered_query = select(Permission).order_by(Permission.display_order)
    rows = db.execute(ordered_query).scalars().all()
    return list(rows)


# ─── Role CRUD ───────────────────────────────────────────────────────────────


def get_role_by_id(db: Session, role_id: int) -> Optional[Role]:
    """Look up a single Role by primary key; return None when no row matches."""
    result = db.execute(select(Role).where(Role.id == role_id))
    return result.scalar_one_or_none()


def get_role_by_slug(db: Session, slug: str) -> Optional[Role]:
    """Look up a single Role by its slug; return None when no row matches."""
    result = db.execute(select(Role).where(Role.slug == slug))
    return result.scalar_one_or_none()


def get_roles_for_merchant(
    db: Session, merchant_id: int, caller_rank: int
) -> list[Role]:
    """
    List the roles a caller may see for a merchant, most privileged first.

    Includes global system roles (merchant_id IS NULL) plus the merchant's own
    custom roles. Roles ranked above ``caller_rank`` are excluded, so a caller
    never learns about roles more privileged than their own.
    """
    owned_or_global = (Role.merchant_id == merchant_id) | (Role.merchant_id.is_(None))
    not_above_caller = Role.role_rank <= caller_rank

    query = select(Role).where(and_(owned_or_global, not_above_caller))
    query = query.order_by(Role.role_rank.desc())

    return list(db.execute(query).scalars().all())


def create_custom_role(
    db: Session,
    merchant_id: int,
    label: str,
    permission_ids: list[int],
    created_by_id: int,
    role_rank: int = 1,
) -> Role:
    """
    Create a merchant-scoped custom role with the given permissions.

    The slug has the form ``custom_<slugified label>_<merchant_id>``. If that
    slug is taken, ``_1``, ``_2``, ... is appended until it is unused.

    Args:
        db: Active session; the new role is added and flushed, not committed.
        merchant_id: Merchant that owns the role.
        label: Human-readable role name.
        permission_ids: Permission primary keys to attach. Ids that match no
            Permission row are skipped, and a warning is logged.
        created_by_id: User recorded as the role's creator.
        role_rank: Privilege rank stored on the role (defaults to 1).

    Returns:
        The new Role after ``db.flush()``, so its primary key is populated.
    """
    import re

    # Keep only lowercase ASCII letters/digits. A label with none of those
    # (e.g. entirely non-Latin text) would otherwise yield the degenerate slug
    # "custom__<merchant_id>", so fall back to a generic "role" component.
    name_part = re.sub(r"[^a-z0-9]+", "_", label.lower()).strip("_") or "role"
    base_slug = f"custom_{name_part}_{merchant_id}"

    # Ensure slug uniqueness by appending a counter if the base slug is taken.
    # NOTE(review): check-then-insert is not atomic; two concurrent creates can
    # pick the same slug — presumably a unique constraint on Role.slug rejects
    # the loser. Confirm.
    slug = base_slug
    counter = 1
    while db.execute(select(Role).where(Role.slug == slug)).scalar_one_or_none():
        slug = f"{base_slug}_{counter}"
        counter += 1

    permissions: list[Permission] = []
    if permission_ids:
        permissions = list(
            db.execute(
                select(Permission).where(Permission.id.in_(permission_ids))
            ).scalars().all()
        )
        # Requested ids with no matching row used to vanish silently.
        unknown_ids = set(permission_ids) - {p.id for p in permissions}
        if unknown_ids:
            logger.warning(
                "create_custom_role: ignoring unknown permission ids %s for merchant %s",
                sorted(unknown_ids),
                merchant_id,
            )

    role = Role(
        label=label,
        slug=slug,
        description=None,
        is_system=False,
        is_default=False,
        role_rank=role_rank,
        merchant_id=merchant_id,
        created_by_id=created_by_id,
    )
    role.permissions = permissions
    db.add(role)
    db.flush()
    return role


def update_role(
    db: Session,
    role_id: int,
    label: Optional[str],
    permission_ids: Optional[list[int]],
    role_rank: Optional[int] = None,
) -> Role:
    """
    Update label, permissions, and/or role_rank on an existing role.

    Any argument passed as None leaves that attribute unchanged. An empty
    ``permission_ids`` list is not None, so it removes every permission.
    Ids in ``permission_ids`` that match no Permission row are dropped, and
    a warning is logged.

    Returns:
        The updated Role after ``db.flush()`` (not committed).

    Raises:
        ValueError: if no role with ``role_id`` exists.
    """
    role = get_role_by_id(db, role_id)
    if role is None:
        # Fail explicitly. Previously this raised AttributeError on None, or
        # returned None (despite the Role return type) when every field was None.
        raise ValueError(f"Role {role_id} not found")

    if label is not None:
        role.label = label
    if role_rank is not None:
        role.role_rank = role_rank
    if permission_ids is not None:
        permissions: list[Permission] = []
        if permission_ids:
            permissions = list(
                db.execute(
                    select(Permission).where(Permission.id.in_(permission_ids))
                ).scalars().all()
            )
            unknown_ids = set(permission_ids) - {p.id for p in permissions}
            if unknown_ids:
                logger.warning(
                    "update_role: ignoring unknown permission ids %s for role %s",
                    sorted(unknown_ids),
                    role_id,
                )
        role.permissions = permissions
    db.flush()
    return role


def delete_role(db: Session, role_id: int) -> None:
    """
    Delete a custom role, first releasing every foreign-key reference to it.

    All preconditions are checked before any row is modified. A failure
    therefore never leaves the session with half-applied bulk updates.

    Steps, in order:
      1. Null out MerchantAuditLog.target_role_id (nullable FK, no ondelete set).
      2. Cancel *pending* MerchantInvites for this role, so step 4 does not
         silently turn them into Staff invites. Accepted invites are not
         marked cancelled, so their history stays accurate.
      3. Reassign UserRole members to the system Staff role.
      4. Remap every remaining MerchantInvite row to Staff. role_id is NOT
         NULL, so cancelled and accepted rows still reference this role.
      5. Delete the role row.

    A role_id that matches no row is a no-op.

    Raises:
        ValueError: if the role is a system role. System roles are global
            (merchant_id IS NULL) and shared by every merchant. Deleting
            Staff would also leave nowhere to reassign its members.
        RuntimeError: if the system "staff" role does not exist (e.g.
            seed_system_roles() has not run). Without it, members and invites
            cannot be reassigned and the final DELETE would hit the FKs.
    """
    role = get_role_by_id(db, role_id)
    if role is None:
        return
    if role.is_system:
        raise ValueError(f"System role '{role.slug}' cannot be deleted")

    # Resolve the reassignment target before touching any rows.
    staff_role = get_role_by_slug(db, "staff")
    if staff_role is None:
        raise RuntimeError(
            "System 'staff' role not found; cannot reassign members of the deleted role"
        )

    now = datetime.now(timezone.utc).replace(tzinfo=None)

    # 1. Null out audit log references.
    db.execute(
        update(MerchantAuditLog)
        .where(MerchantAuditLog.target_role_id == role_id)
        .values(target_role_id=None)
    )

    # 2. Cancel invites that are still pending. Setting cancelled_at does not
    # release the FK (step 4 does); it only prevents these invites from being
    # accepted later as Staff invites.
    db.execute(
        update(MerchantInvite)
        .where(
            MerchantInvite.role_id == role_id,
            MerchantInvite.accepted_at.is_(None),
            MerchantInvite.cancelled_at.is_(None),
        )
        .values(cancelled_at=now)
    )

    # 3. Reassign members with this role to Staff.
    db.execute(
        update(UserRole)
        .where(UserRole.role_id == role_id)
        .values(role_id=staff_role.id)
    )

    # 4. Remap all invite rows (now cancelled, accepted, or previously
    # cancelled). role_id is NOT NULL, so they cannot be nulled.
    db.execute(
        update(MerchantInvite)
        .where(MerchantInvite.role_id == role_id)
        .values(role_id=staff_role.id)
    )

    # 5. Delete the role.
    # NOTE(review): an ORM-enabled bulk DELETE does not run relationship
    # cascades. roles_permissions rows are removed only if that table's FK
    # declares ON DELETE CASCADE — confirm in the schema/migration.
    db.execute(delete(Role).where(Role.id == role_id))
    db.flush()


# ─── UserRole CRUD ───────────────────────────────────────────────────────────


def get_user_role(
    db: Session, user_id: int, merchant_id: int
) -> Optional[UserRole]:
    """Return the user's UserRole row within ``merchant_id``, or None if absent."""
    query = (
        select(UserRole)
        .where(UserRole.user_id == user_id)
        .where(UserRole.merchant_id == merchant_id)
    )
    result = db.execute(query)
    return result.scalar_one_or_none()


def assign_role_to_user(
    db: Session, user_id: int, role_id: int, merchant_id: int
) -> UserRole:
    """
    Give a user ``role_id`` within a merchant, creating the membership if needed.

    An existing UserRole row for (user_id, merchant_id) has its role_id
    overwritten. Otherwise a new primary membership row is added. Either way
    the session is flushed and the row is returned.
    """
    membership = get_user_role(db, user_id, merchant_id)

    if membership is not None:
        membership.role_id = role_id
    else:
        membership = UserRole(
            user_id=user_id,
            role_id=role_id,
            merchant_id=merchant_id,
            is_primary=True,
        )
        db.add(membership)

    db.flush()
    return membership


def remove_member(db: Session, merchant_id: int, user_id: int) -> None:
    """
    Remove a user from a merchant account.

    Revokes every active, not-yet-revoked AuthSession belonging to the user,
    then deletes the user's UserRole row for ``merchant_id`` and flushes.

    NOTE(review): the session update is not filtered by merchant, so a user
    who belongs to several merchants is logged out of all of them. Confirm
    that this is intended.
    NOTE(review): revoked_at is written as a timezone-aware UTC datetime,
    whereas the invite helpers in this module store naive UTC. Confirm that
    AuthSession.revoked_at is a timezone-aware column.
    """
    # Imported lazily, presumably to avoid an import cycle with the auth app.
    from src.apps.auth.models.auth_session import AuthSession

    revoked_at = datetime.now(timezone.utc)
    revoke_sessions = (
        update(AuthSession)
        .where(AuthSession.user_id == user_id)
        .where(AuthSession.is_active == True)  # noqa: E712 — SQL comparison
        .where(AuthSession.is_revoked == False)  # noqa: E712 — SQL comparison
        .values(is_active=False, is_revoked=True, revoked_at=revoked_at)
    )
    drop_membership = (
        delete(UserRole)
        .where(UserRole.user_id == user_id)
        .where(UserRole.merchant_id == merchant_id)
    )

    db.execute(revoke_sessions)
    db.execute(drop_membership)
    db.flush()


def get_members_for_merchant(
    db: Session, merchant_id: int, caller_rank: int
) -> list[UserRole]:
    """
    List a merchant's memberships, most privileged role first.

    Members whose role ranks above ``caller_rank`` are omitted, so a caller
    never sees anyone more privileged than themselves.
    """
    query = select(UserRole).join(Role, UserRole.role_id == Role.id)
    query = query.where(UserRole.merchant_id == merchant_id)
    query = query.where(Role.role_rank <= caller_rank)
    query = query.order_by(Role.role_rank.desc())

    members = db.execute(query).scalars().all()
    return list(members)


# ─── Permission lookup ────────────────────────────────────────────────────────


def get_user_permissions_from_db(
    db: Session, user_id: int, merchant_id: int
) -> set[str]:
    """
    Collect the permission slugs a user holds within a merchant.

    Walks the user's UserRole row, then its ``role`` relationship, then that
    role's ``permissions`` collection. Returns an empty set when the user has
    no membership in the merchant, or the membership has no role.
    """
    membership = get_user_role(db, user_id, merchant_id)
    role = membership.role if membership else None
    if not role:
        return set()

    slugs: set[str] = set()
    for permission in role.permissions:
        slugs.add(permission.slug)
    return slugs


# ─── Invite CRUD ─────────────────────────────────────────────────────────────


def get_invite_by_id(
    db: Session, invite_id: int, merchant_id: int
) -> Optional[MerchantInvite]:
    """
    Fetch an invite by id, scoped to ``merchant_id``.

    Returns None both when the id is unknown and when the invite belongs to
    a different merchant.
    """
    query = (
        select(MerchantInvite)
        .where(MerchantInvite.id == invite_id)
        .where(MerchantInvite.merchant_id == merchant_id)
    )
    return db.execute(query).scalar_one_or_none()


def get_invite_by_token_hash(
    db: Session, token_hash: str
) -> Optional[MerchantInvite]:
    """
    Find the invite whose stored token hash equals ``token_hash``, or None.

    The lookup is not scoped to a merchant and does not check the invite's
    accepted, cancelled or expiry state. Callers must do that themselves.
    """
    token_matches = MerchantInvite.token_hash == token_hash
    result = db.execute(select(MerchantInvite).where(token_matches))
    return result.scalar_one_or_none()


def get_pending_invites(db: Session, merchant_id: int) -> list[MerchantInvite]:
    """
    List the merchant's invites that are neither accepted nor cancelled.

    Expiry is not filtered here, so invites past ``expires_at`` are still
    returned. No ordering is applied.
    """
    still_open = (
        MerchantInvite.accepted_at.is_(None),
        MerchantInvite.cancelled_at.is_(None),
    )
    query = select(MerchantInvite).where(
        MerchantInvite.merchant_id == merchant_id, *still_open
    )
    invites = db.execute(query).scalars().all()
    return list(invites)


def create_invite(
    db: Session,
    merchant_id: int,
    email: str,
    role_id: int,
    token_hash: str,
    invited_by: int,
    expires_at: datetime,
) -> MerchantInvite:
    """
    Persist a new, open invite and flush so its primary key is available.

    The invite starts with neither accepted_at nor cancelled_at set. Only the
    token's hash is stored, never the raw token.
    """
    fields = {
        "merchant_id": merchant_id,
        "email": email,
        "role_id": role_id,
        "token_hash": token_hash,
        "invited_by": invited_by,
        "expires_at": expires_at,
        # Explicitly open: not yet accepted or cancelled.
        "accepted_at": None,
        "cancelled_at": None,
    }
    new_invite = MerchantInvite(**fields)
    db.add(new_invite)
    db.flush()
    return new_invite


def cancel_invite(db: Session, invite_id: int, merchant_id: int) -> bool:
    """
    Cancel a still-pending invite that belongs to ``merchant_id``.

    Only invites that are neither accepted nor already cancelled are updated,
    using the same guards as accept_invite. This keeps accepted invites from
    also being stamped as cancelled, and stops a repeat cancel from
    overwriting the original cancelled_at. Timestamps are stored as naive
    UTC, consistent with the rest of this module.

    Returns:
        True if this call cancelled the invite. False if no matching pending
        invite exists: unknown id, another merchant's invite, already
        accepted, or already cancelled. Callers that ignored the former
        ``None`` return are unaffected.
    """
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    result = db.execute(
        update(MerchantInvite)
        .where(
            MerchantInvite.id == invite_id,
            MerchantInvite.merchant_id == merchant_id,
            MerchantInvite.accepted_at.is_(None),   # don't rewrite accepted history
            MerchantInvite.cancelled_at.is_(None),  # keep the first cancel timestamp
        )
        .values(cancelled_at=now)
    )
    db.flush()
    return result.rowcount == 1


def accept_invite(db: Session, invite_id: int) -> bool:
    """
    Claim an invite by stamping accepted_at, in a single conditional UPDATE.

    The WHERE clause only matches an invite that is neither accepted nor
    cancelled, so when concurrent requests race, at most one sees a matched
    row. Expiry (``expires_at``) is not checked here; callers must check it.

    Returns:
        True if this call claimed the invite (rowcount == 1). False if the
        invite was already accepted, has been cancelled, or does not exist.
    """
    claimed_at = datetime.now(timezone.utc).replace(tzinfo=None)
    claim = (
        update(MerchantInvite)
        .where(MerchantInvite.id == invite_id)
        .where(MerchantInvite.accepted_at.is_(None))    # atomic guard
        .where(MerchantInvite.cancelled_at.is_(None))   # cancelled invites can't be claimed
        .values(accepted_at=claimed_at)
    )
    outcome = db.execute(claim)
    db.flush()
    return outcome.rowcount == 1


# ─── Audit log ───────────────────────────────────────────────────────────────


def write_audit_log(
    db: Session,
    merchant_id: int,
    actor_id: int,
    action: str,
    target_user_id: Optional[int] = None,
    target_role_id: Optional[int] = None,
    metadata: Optional[dict] = None,
) -> MerchantAuditLog:
    """
    Append one MerchantAuditLog row and flush so its primary key is available.

    ``actor_id`` is stored in the ``actor_user_id`` column. ``metadata`` is
    stored in the model attribute ``metadata_`` (the trailing underscore is
    likely there because ``metadata`` is reserved on SQLAlchemy declarative
    models).
    """
    record = MerchantAuditLog(
        action=action,
        merchant_id=merchant_id,
        actor_user_id=actor_id,
        target_user_id=target_user_id,
        target_role_id=target_role_id,
        metadata_=metadata,
    )
    db.add(record)
    db.flush()
    return record


# ─── System seed ─────────────────────────────────────────────────────────────


def seed_system_roles(db: Session) -> None:
    """
    Insert any of the Owner / Admin / Staff system roles that are missing.

    Roles are matched by slug. A role that already exists is left untouched,
    even if its stored fields differ from _SYSTEM_ROLES. Safe to call
    repeatedly; the session is flushed once at the end.
    """
    for definition in _SYSTEM_ROLES:
        if not get_role_by_slug(db, definition["slug"]):
            # System roles are global: no owning merchant and no creating user.
            db.add(Role(**definition, merchant_id=None, created_by_id=None))
    db.flush()
