"""tax001 - Add TaxJar integration: tax_rate_cache table + checkout/invoice tax fields

Creates:
  - tax_rate_cache table (HWTAX-106)
Adds columns:
  - checkouts.tax_percent, checkouts.tax_type (HWTAX-105)
  - checkout_line_items.tax, checkout_line_items.tax_amount (HWTAX-105)
  - invoices_line_items.tax_amount (HWTAX-105)

Revision ID: tax001_taxjar_integration
Revises: chk003_tip_allow_custom
Create Date: 2026-04-02
"""
from typing import Union, Sequence

import sqlalchemy as sa
from alembic import op

# Alembic revision identifiers.
# `revision` is this migration's ID; `down_revision` is the revision it revises.
revision: str = 'tax001_taxjar_integration'
down_revision: Union[str, Sequence[str], None] = 'chk003_tip_allow_custom'
# No branch labels or cross-revision dependencies are declared.
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the tax-rate cache table and add tax columns to existing tables.

    Steps, in order:
      1. Create ``tax_rate_cache`` with a unique constraint on
         (``zip``, ``country``).
      2. Add ``tax_percent`` and ``tax_type`` to ``checkouts``
         unconditionally.
      3. Add ``tax`` / ``tax_amount`` to ``checkout_line_items`` and
         ``tax_amount`` to ``invoices_line_items``, each only when the
         column is not already present on the table.
    """

    def nullable_zero_float(column_name: str) -> sa.Column:
        # Nullable FLOAT column with a server-side default of 0.
        return sa.Column(column_name, sa.Float(), nullable=True, server_default='0')

    def required_rate(column_name: str) -> sa.Column:
        # NOT NULL FLOAT column with a server-side default of 0.
        return sa.Column(column_name, sa.Float(), nullable=False, server_default='0')

    # ── tax_rate_cache ────────────────────────────────────────────────────────
    op.create_table(
        'tax_rate_cache',
        sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
        # The index on 'zip' is requested inline via index=True, so no
        # separate create_index call is issued afterwards.
        sa.Column('zip', sa.String(10), nullable=False, index=True),
        sa.Column('city', sa.String(100), nullable=True),
        sa.Column('state', sa.String(10), nullable=True),
        sa.Column('country', sa.String(10), nullable=False, server_default='US'),
        sa.Column('label', sa.String(255), nullable=False),
        sa.Column('combined_rate', sa.Float(), nullable=False),
        required_rate('state_rate'),
        required_rate('county_rate'),
        required_rate('city_rate'),
        required_rate('special_district_rate'),
        sa.Column('cached_at', sa.DateTime(), nullable=False),
        sa.Column('expires_at', sa.DateTime(), nullable=False),
        sa.UniqueConstraint('zip', 'country', name='uq_tax_rate_zip_country'),
    )

    # ── checkouts: tax fields (always added) ──────────────────────────────────
    op.add_column('checkouts', nullable_zero_float('tax_percent'))
    op.add_column(
        'checkouts',
        sa.Column('tax_type', sa.String(30), nullable=True, server_default='notax'),
    )

    # ── line-item tables: tax fields (added only if absent) ───────────────────
    # These columns may already exist from a prior migration, so each table is
    # inspected once and only the missing columns are added.
    inspector = sa.inspect(op.get_bind())
    guarded_additions = (
        ('checkout_line_items', ('tax', 'tax_amount')),
        ('invoices_line_items', ('tax_amount',)),
    )
    for table_name, column_names in guarded_additions:
        present = {col['name'] for col in inspector.get_columns(table_name)}
        for column_name in column_names:
            if column_name not in present:
                op.add_column(table_name, nullable_zero_float(column_name))


def downgrade() -> None:
    """Drop the tax columns, then the ``tax_rate_cache`` table.

    Columns are removed in the reverse of the order in which ``upgrade()``
    handles them, and the cache table is dropped last.

    NOTE(review): every drop here is unconditional, whereas ``upgrade()``
    skips the ``checkout_line_items`` / ``invoices_line_items`` columns when
    they already exist. If those columns pre-dated this revision, downgrading
    removes them anyway — confirm that this is intended.
    """
    columns_to_drop = (
        ('invoices_line_items', 'tax_amount'),
        ('checkout_line_items', 'tax_amount'),
        ('checkout_line_items', 'tax'),
        ('checkouts', 'tax_type'),
        ('checkouts', 'tax_percent'),
    )
    for table_name, column_name in columns_to_drop:
        op.drop_column(table_name, column_name)

    op.drop_table('tax_rate_cache')
