"""
Rerun token checks for recently EXPIRED owner/editor UserAccount refresh tokens.

Run from the repo root:
    uv run python manage.py shell < /tmp/rerun_recent_expired_token_checks.py

Scope:
    - role in ("editor", "owner")
    - status = "EXPIRED"
    - last_token_check_time in the last 5 rolling days

This intentionally does not call the shared validator because that path can also
update Account names. This script only updates status and last_token_check_time
for the selected UserAccount rows.
"""

from collections import defaultdict
from datetime import timedelta

from django.db import transaction
from django.utils import timezone

from organization_auth.exceptions import TokenExpiredException, TokenInvalidException
from organization_auth.models import Account, UserAccount
from organization_auth.services import get_ad_accounts_list
from organization_auth.tasks.auth_tasks import (
    _build_token_validation_credentials,
    _extract_accessible_accounts,
    _normalize_account_id,
)


LOOKBACK_DAYS = 5
ROLES = ("editor", "owner")


def token_key(user_account):
    """Return the ``(platform, token_hex)`` grouping key for a UserAccount row.

    The refresh token is converted to ``bytes`` first: values exposing a
    ``tobytes()`` method (such as ``memoryview``) are converted through it,
    anything else is passed to ``bytes()``. The hex form of those bytes is
    paired with the platform of the row's account.
    """
    raw_token = user_account.refresh_token
    try:
        to_bytes = raw_token.tobytes
    except AttributeError:
        # No ``tobytes`` attribute: fall back to the ``bytes`` constructor.
        raw_bytes = bytes(raw_token)
    else:
        raw_bytes = to_bytes()
    return user_account.account.platform, raw_bytes.hex()


# Start of the rolling lookback window. ``timedelta`` is taken from the
# standard library instead of relying on the name happening to be importable
# from ``django.utils.timezone``, which is not part of that module's
# documented API.
cutoff = timezone.now() - timedelta(days=LOOKBACK_DAYS)

# Candidate rows: owner/editor rows currently EXPIRED that still hold a
# refresh token and whose last token check falls inside the lookback window.
# The queryset is materialized so the row count can be reported and the rows
# can be grouped in memory; the ordering keeps that grouping deterministic.
user_accounts = list(
    UserAccount.objects.filter(
        role__in=ROLES,
        status=UserAccount.UserAccountStatus.EXPIRED,
        refresh_token__isnull=False,
        last_token_check_time__gte=cutoff,
    )
    .select_related("account", "user")
    .order_by("account__platform", "user_id", "account_id", "id")
)

# Bucket rows by (platform, token) so each distinct token is checked once.
groups = defaultdict(list)
for user_account in user_accounts:
    groups[token_key(user_account)].append(user_account)

print(
    f"Selected {len(user_accounts)} UserAccount rows across "
    f"{len(groups)} unique platform/token groups since {cutoff.isoformat()}."
)

# Run counters, all starting at zero; the key order is the order in which the
# totals appear when the dict is printed.
summary = dict.fromkeys(
    (
        "groups_checked",
        "rows_checked",
        "tokens_invalid",
        "rows_set_active",
        "rows_kept_expired",
        "errors",
    ),
    0,
)

# Re-check each platform/token group once. The first row of a group acts as
# the representative: it is handed to the credential builder and supplies the
# user passed to the account lookup.
for (platform, _token_hex), rows in groups.items():
    # Taken before the lookup so every row in the group gets the same stamp.
    checked_at = timezone.now()
    representative = rows[0]
    account_ids = [str(row.account_id) for row in rows]
    user_account_ids = [str(row.id) for row in rows]

    print(
        "Checking "
        f"platform={platform} rows={len(rows)} "
        f"user_account_ids={user_account_ids} account_ids={account_ids}"
    )

    try:
        credentials_dict = _build_token_validation_credentials(
            representative,
            rows,
            platform,
        )
        account_list = get_ad_accounts_list(
            credentials_dict,
            platform,
            representative.user,
            list_only=platform != Account.Platform.LINKEDIN,
        )
    except (TokenExpiredException, TokenInvalidException) as exc:
        # The token was rejected: every row in the group stays EXPIRED and
        # only the check timestamp is refreshed.
        with transaction.atomic():
            UserAccount.objects.filter(id__in=user_account_ids).update(
                status=UserAccount.UserAccountStatus.EXPIRED,
                last_token_check_time=checked_at,
            )

        summary["tokens_invalid"] += 1
        summary["groups_checked"] += 1
        summary["rows_checked"] += len(rows)
        summary["rows_kept_expired"] += len(rows)
        print(f"  token invalid/expired; kept EXPIRED: {exc}")
        continue
    except Exception as exc:
        # Any other failure is only counted and reported; no row is written.
        summary["errors"] += 1
        print(f"  ERROR: token check failed without status changes: {exc}")
        continue

    accessible_account_ids, _account_name_map = _extract_accessible_accounts(
        platform,
        account_list,
    )

    # Normalized ids the lookup returned; falsy ids are skipped.
    normalized_accessible_account_ids = set()
    for accessible_id in accessible_account_ids:
        if accessible_id:
            normalized_accessible_account_ids.add(
                _normalize_account_id(platform, accessible_id)
            )

    # Split the group's row ids by whether the row's own (normalized) ad
    # account id appears among the accessible ones.
    ids_by_accessibility = {True: [], False: []}
    for row in rows:
        row_account_id = _normalize_account_id(platform, row.account.ad_account_id)
        is_accessible = row_account_id in normalized_accessible_account_ids
        ids_by_accessibility[is_accessible].append(row.id)
    active_ids = ids_by_accessibility[True]
    expired_ids = ids_by_accessibility[False]

    # Both status writes for the group share one transaction; an empty id
    # list issues no query.
    with transaction.atomic():
        for target_ids, target_status in (
            (active_ids, UserAccount.UserAccountStatus.ACTIVE),
            (expired_ids, UserAccount.UserAccountStatus.EXPIRED),
        ):
            if target_ids:
                UserAccount.objects.filter(id__in=target_ids).update(
                    status=target_status,
                    last_token_check_time=checked_at,
                )

    summary["groups_checked"] += 1
    summary["rows_checked"] += len(rows)
    summary["rows_set_active"] += len(active_ids)
    summary["rows_kept_expired"] += len(expired_ids)
    print(f"  token valid; set ACTIVE={len(active_ids)} kept EXPIRED={len(expired_ids)}")

print(f"Done: {summary}")
