"""
Billing date computation utilities for the Subscriptions domain.
"""

from datetime import datetime, timedelta, timezone
from typing import List, Optional, TYPE_CHECKING

from dateutil.relativedelta import relativedelta


def compute_next_billing_date(
    interval: str,
    interval_value: int,
    from_date: datetime,
) -> datetime:
    """
    Return ``from_date`` moved forward by exactly one billing period.

    Args:
        interval: period unit, one of "day", "week", "month", "quarter" or
            "year" (matched case-insensitively). Any other value is handled
            the same way as "month".
        interval_value: number of units that make up one period
            (e.g. 2 with "week" for a bi-weekly cycle).
        from_date: the datetime to advance.

    Returns:
        ``from_date`` plus one period. "day"/"week" add a fixed ``timedelta``;
        "month"/"quarter"/"year" add a calendar ``relativedelta`` (a quarter
        being three months).
    """
    # Each entry turns the multiplier into the offset to add. Building the
    # offset lazily means only the selected unit's constructor is invoked.
    step_builders = {
        "day": lambda n: timedelta(days=n),
        "week": lambda n: timedelta(weeks=n),
        "month": lambda n: relativedelta(months=n),
        "quarter": lambda n: relativedelta(months=3 * n),
        "year": lambda n: relativedelta(years=n),
    }
    # Unknown units intentionally fall back to monthly billing.
    build_step = step_builders.get(interval.lower(), step_builders["month"])
    return from_date + build_step(interval_value)


def compute_projected_dates(
    interval: str,
    interval_value: int,
    from_date: datetime,
    count: int,
    end_type: Optional[str] = None,
    end_date: Optional[datetime] = None,
    pay_until_count: Optional[int] = None,
    already_generated: int = 0,
) -> List[datetime]:
    """
    Project upcoming billing dates, stopping at the first end condition hit.

    ``from_date`` itself is the first candidate; each following candidate is
    obtained by applying ``compute_next_billing_date`` to the previous one.

    Args:
        interval: billing interval unit, forwarded to compute_next_billing_date
        interval_value: interval multiplier, forwarded likewise
        from_date: first candidate date (the subscription's next_billing_date)
        count: maximum number of dates to return
        end_type: "date" | "until_count" | "until_cancelled" | None. Only
            "date" and "until_count" restrict the projection.
        end_date: cutoff used when end_type == "date"; a candidate equal to
            it is still included, only later ones are dropped
        pay_until_count: total payments allowed when end_type ==
            "until_count"; None or 0 leaves the projection uncapped
        already_generated: invoices already generated; they count toward
            pay_until_count

    Returns:
        Up to ``count`` datetimes in the order they were generated, with
        their original tzinfo left untouched.

    Comparisons against end_date happen on naive UTC values: aware datetimes
    are converted to UTC and stripped, naive ones are assumed to be UTC.

    NOTE(review): because each date is chained from the previous one, month
    schedules anchored on day 29-31 drift (e.g. Jan 31 -> Feb 29 -> Mar 29).
    Presumably this mirrors how next_billing_date is advanced elsewhere —
    confirm before changing to an anchored calculation.
    """

    def _as_naive_utc(moment: datetime) -> datetime:
        # Aware -> shift to UTC and drop tzinfo; naive -> returned as-is.
        if moment.tzinfo is None:
            return moment
        return moment.astimezone(timezone.utc).replace(tzinfo=None)

    cutoff = _as_naive_utc(end_date) if end_date is not None else None

    # The end conditions do not change during the loop, so decide up front
    # which of them are active.
    stop_on_date = end_type == "date" and cutoff is not None
    stop_on_count = end_type == "until_count" and bool(pay_until_count)

    projected: List[datetime] = []
    candidate = from_date
    issued = already_generated
    for _ in range(count):
        if stop_on_date and _as_naive_utc(candidate) > cutoff:
            break
        if stop_on_count and issued >= pay_until_count:
            break
        projected.append(candidate)
        candidate = compute_next_billing_date(interval, interval_value, candidate)
        issued += 1

    return projected


def get_interval_label(interval: str, interval_value: int) -> str:
    """
    Build a human-readable label for a billing interval.

    The interval unit is matched case-insensitively. A multiplier of 1 gives
    an adverb ("Monthly"); any other multiplier gives "Every N <Units>".
    Unknown units are capitalised and, for the plural form, suffixed with "s".

    Examples:
        month / 1  → "Monthly"
        week / 2   → "Every 2 Weeks"
        month / 3  → "Every 3 Months"
    """
    unit = interval.lower()

    if interval_value == 1:
        adverb_for = {
            "day": "Daily",
            "week": "Weekly",
            "month": "Monthly",
            "quarter": "Quarterly",
            "year": "Yearly",
        }
        if unit in adverb_for:
            return adverb_for[unit]
        return f"Every 1 {unit.capitalize()}"

    plural_for = {
        "day": "Days",
        "week": "Weeks",
        "month": "Months",
        "quarter": "Quarters",
        "year": "Years",
    }
    if unit in plural_for:
        units = plural_for[unit]
    else:
        units = unit.capitalize() + 's'
    return f"Every {interval_value} {units}"
