import json
from typing import Any, cast

import httpx
from rich.console import Console

from crewai_cli.authentication.main import Oauth2Settings, ProviderFactory
from crewai_cli.command import BaseCommand
from crewai_cli.settings.main import SettingsCommand
from crewai_cli.version import get_crewai_version


# Module-level Rich console used for the status and error messages printed
# by the command below.
console = Console()


class EnterpriseConfigureCommand(BaseCommand):
    """CLI command that configures local OAuth2 settings from an enterprise URL.

    The command downloads the OAuth2 parameters published at
    ``<enterprise_url>/auth/parameters``, validates them, and stores each value
    through ``SettingsCommand``.
    """

    def __init__(self) -> None:
        """Initialise the base command and the settings writer."""
        super().__init__()
        self.settings_command = SettingsCommand()

    def configure(self, enterprise_url: str) -> None:
        """Fetch, validate and persist the OAuth2 settings for an enterprise URL.

        Args:
            enterprise_url: Base URL of the enterprise installation. Trailing
                slashes are stripped before the URL is used or stored.

        Raises:
            SystemExit: With exit code 1 when any step fails. The failure is
                printed to the console before exiting.
        """
        try:
            enterprise_url = enterprise_url.rstrip("/")

            oauth_config = self._fetch_oauth_config(enterprise_url)

            self._update_oauth_settings(enterprise_url, oauth_config)

            console.print(
                f"✅ Successfully configured CrewAI AMP with OAuth2 settings from {enterprise_url}",
                style="bold green",
            )

        except Exception as e:
            # Top-level CLI boundary: report the problem and exit non-zero.
            console.print(
                f"❌ Failed to configure Enterprise settings: {e!s}", style="bold red"
            )
            raise SystemExit(1) from e

    def _fetch_oauth_config(self, enterprise_url: str) -> dict[str, Any]:
        """Download and validate the OAuth2 configuration of an enterprise URL.

        Args:
            enterprise_url: Base URL (without trailing slash) to query.

        Returns:
            The decoded JSON object returned by ``/auth/parameters``.

        Raises:
            ValueError: If the request fails, the response is not valid JSON,
                the payload is not a JSON object, or validation fails.
        """
        oauth_endpoint = f"{enterprise_url}/auth/parameters"

        try:
            console.print(f"🔄 Fetching OAuth2 configuration from {oauth_endpoint}...")
            # Resolve the version once; it is sent in two headers.
            crewai_version = get_crewai_version()
            headers = {
                "Content-Type": "application/json",
                "User-Agent": f"CrewAI-CLI/{crewai_version}",
                "X-Crewai-Version": crewai_version,
            }
            response = httpx.get(oauth_endpoint, timeout=30, headers=headers)
            response.raise_for_status()

            try:
                oauth_config = response.json()
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e

            # Valid JSON may still be a list, string, number or null; only an
            # object can carry the named fields validated below.
            if not isinstance(oauth_config, dict):
                raise ValueError(
                    f"Invalid OAuth2 configuration from {oauth_endpoint}: "
                    f"expected a JSON object, got {type(oauth_config).__name__}"
                )

            self._validate_oauth_config(oauth_config)

            console.print(
                "✅ Successfully retrieved OAuth2 configuration", style="green"
            )
            return cast(dict[str, Any], oauth_config)

        except httpx.HTTPError as e:
            raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
        except Exception as e:
            # Also re-wraps the ValueErrors raised above, so callers always
            # see the "Error fetching OAuth2 configuration" prefix.
            raise ValueError(f"Error fetching OAuth2 configuration: {e!s}") from e

    def _update_oauth_settings(
        self, enterprise_url: str, oauth_config: dict[str, Any]
    ) -> None:
        """Store the enterprise URL and OAuth2 values via the settings command.

        Args:
            enterprise_url: Base URL saved as ``enterprise_base_url``.
            oauth_config: OAuth2 configuration containing the ``provider``,
                ``audience``, ``device_authorization_client_id``, ``domain``
                and ``extra`` keys.

        Raises:
            ValueError: If a key is missing from ``oauth_config`` or a setting
                cannot be written.
        """
        try:
            config_mapping = {
                "enterprise_base_url": enterprise_url,
                "oauth2_provider": oauth_config["provider"],
                "oauth2_audience": oauth_config["audience"],
                "oauth2_client_id": oauth_config["device_authorization_client_id"],
                "oauth2_domain": oauth_config["domain"],
                "oauth2_extra": oauth_config["extra"],
            }

            console.print("🔄 Updating local OAuth2 configuration...")

            for key, value in config_mapping.items():
                self.settings_command.set(key, value)
                console.print(f"  ✓ Set {key}: {value}", style="dim")

        except Exception as e:
            raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e

    def _validate_oauth_config(self, oauth_config: dict[str, Any]) -> None:
        """Check that the OAuth2 configuration contains every required field.

        The generic fields are checked first. Provider-specific fields are
        looked up only once ``provider`` and ``extra`` are known to be present,
        so a payload missing ``provider`` reports a clear error rather than a
        ``KeyError``.

        Args:
            oauth_config: Decoded OAuth2 configuration to validate.

        Raises:
            ValueError: If a generic field is missing, or if ``extra`` lacks a
                field required by the configured provider.
        """
        required_fields = [
            "audience",
            "domain",
            "device_authorization_client_id",
            "provider",
            "extra",
        ]

        missing_basic_fields = [
            field for field in required_fields if field not in oauth_config
        ]
        if missing_basic_fields:
            raise ValueError(
                f"Missing required fields in OAuth2 configuration: [{', '.join(missing_basic_fields)}]"
            )

        provider_name = oauth_config["provider"]
        # "extra" may be present but null; treat that as "no extra fields"
        # so the membership test below cannot fail with a TypeError.
        extra = oauth_config["extra"] or {}

        missing_provider_specific_fields = [
            field
            for field in self._get_provider_specific_fields(provider_name)
            if field not in extra
        ]
        if missing_provider_specific_fields:
            raise ValueError(
                f"Missing authentication provider required fields in OAuth2 configuration: [{', '.join(missing_provider_specific_fields)}] (Configured provider: '{provider_name}')"
            )

    def _get_provider_specific_fields(self, provider_name: str) -> list[str]:
        """Return the extra fields required by the named authentication provider.

        A provider is built from placeholder ``client_id`` and ``domain``
        values solely to call its ``get_required_fields()``.

        Args:
            provider_name: Provider identifier taken from the OAuth2
                configuration.

        Returns:
            The field names reported by the provider's
            ``get_required_fields()``.
        """
        provider = ProviderFactory.from_settings(
            Oauth2Settings(provider=provider_name, client_id="dummy", domain="dummy")
        )

        return provider.get_required_fields()
