from collections import defaultdict
from collections.abc import Sequence
import json
import os
from pathlib import Path
import time
from typing import Any

import certifi
import click
import httpx

from crewai_cli.constants import JSON_URL, MODELS, PROVIDERS


def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
    """Presents a numbered menu to the user and prompts them to select one entry.

    Args:
        prompt_message: The message to display to the user before presenting the choices.
        choices: The options to present to the user, shown numbered from 1.

    Returns:
        The selected choice from the list, or None if the user chooses to quit
        or if provider data could not be loaded.
    """
    # Existing behaviour kept for callers: the menu is only shown when provider
    # data can be loaded; the loaded data itself is not used here.
    if not get_provider_data():
        return None

    click.secho(prompt_message, fg="cyan")
    for idx, option in enumerate(choices, start=1):
        click.secho(f"{idx}. {option}", fg="cyan")
    click.secho("q. Quit", fg="cyan")

    while True:
        answer = click.prompt(
            "Enter the number of your choice or 'q' to quit", type=str
        ).strip()

        if answer.lower() == "q":
            return None

        try:
            selected_index = int(answer) - 1
        except ValueError:
            # Non-numeric input is treated like an out-of-range number.
            selected_index = -1

        if 0 <= selected_index < len(choices):
            return choices[selected_index]

        # Report the real upper bound rather than a fixed number.
        click.secho(
            f"Invalid selection. Please select a number between 1 and "
            f"{len(choices)} or 'q' to quit.",
            fg="red",
        )


def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
    """Asks the user which provider to configure.

    The predefined providers are offered first, together with an "other"
    entry that opens a second menu listing every known provider.

    Args:
        provider_models: A dictionary mapping provider names to their models.

    Returns:
        The chosen provider in lower case, None if the user quits either menu,
        or False if the selection is empty.
    """
    builtin_names = [name.lower() for name in PROVIDERS]
    full_list = sorted({*builtin_names, *provider_models.keys()})

    picked = select_choice("Select a provider to set up:", builtin_names + ["other"])
    if picked is None:  # User typed 'q'
        return None

    if picked == "other":
        picked = select_choice("Select a provider from the full list:", full_list)
        if picked is None:  # User typed 'q'
            return None

    if not picked:
        return False
    return picked.lower()


def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
    """Asks the user to pick a model for the given provider.

    Predefined providers take their model list from MODELS; any other provider
    takes it from ``provider_models``.

    Args:
        provider: The provider for which to select a model.
        provider_models: A dictionary mapping provider names to their models.

    Returns:
        The selected model, or None if no models are available or the user quits.
    """
    known_providers = [name.lower() for name in PROVIDERS]

    if provider not in known_providers:
        candidates = provider_models.get(provider, [])
    else:
        candidates = MODELS.get(provider, [])

    if candidates:
        return select_choice(
            f"Select a model to use for {provider.capitalize()}:", candidates
        )

    click.secho(f"No models available for provider '{provider}'.", fg="red")
    return None


def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
    """Returns provider data from the cache, or from the web when the cache is unusable.

    The cache is used only if the file exists, is younger than ``cache_expiry``
    seconds, and yields non-empty data; otherwise the data is fetched again.

    Args:
        cache_file: The path to the cache file.
        cache_expiry: The cache expiry time in seconds.

    Returns:
        The loaded provider data or None if the operation fails.
    """
    now = time.time()
    cache_is_fresh = (
        cache_file.exists() and (now - cache_file.stat().st_mtime) < cache_expiry
    )

    if not cache_is_fresh:
        click.secho(
            "Cache expired or not found. Fetching provider data from the web...",
            fg="cyan",
        )
        return fetch_provider_data(cache_file)

    cached = read_cache_file(cache_file)
    if cached:
        return cached

    click.secho(
        "Cache is corrupted. Fetching provider data from the web...", fg="yellow"
    )
    return fetch_provider_data(cache_file)


def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
    """Reads and returns the JSON content from a cache file.

    Args:
        cache_file: The path to the cache file.

    Returns:
        The JSON content of the cache file or None if the JSON is invalid.
    """
    try:
        with open(cache_file, "r") as f:
            data: dict[str, Any] = json.load(f)
            return data
    except json.JSONDecodeError:
        return None


def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
    """Fetches provider data from JSON_URL and caches it to a file.

    Writing the cache is best-effort: if the file cannot be written, a warning
    is printed and the downloaded data is still returned.

    Args:
        cache_file: The path to the cache file.

    Returns:
        The fetched provider data or None if the download or parsing fails.
    """
    # Also exported through the environment, as the original code did.
    ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()

    try:
        with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
            response.raise_for_status()
            data = download_data(response)
    except httpx.HTTPError as e:
        click.secho(f"Error fetching provider data: {e}", fg="red")
        return None
    except (json.JSONDecodeError, UnicodeDecodeError):
        # The payload is decoded as UTF-8 before parsing; both steps can fail.
        click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
        return None

    # Write the cache after the connection is closed; a write failure must not
    # discard data that was downloaded successfully.
    try:
        with open(cache_file, "w", encoding="utf-8") as f:
            json.dump(data, f)
    except OSError as e:
        click.secho(f"Warning: could not write provider cache: {e}", fg="yellow")
    return data


def download_data(response: httpx.Response) -> dict[str, Any]:
    """Reads the body of a streamed HTTP response and parses it as JSON.

    A progress bar sized from the ``content-length`` header (0 when the header
    is absent) is advanced as each chunk arrives.

    Args:
        response: The HTTP response object.

    Returns:
        The JSON content of the response.
    """
    chunk_size = 8192
    expected_bytes = int(response.headers.get("content-length", 0))
    received = bytearray()

    progress: Any
    with click.progressbar(
        length=expected_bytes, label="Downloading", show_pos=True
    ) as progress:
        for piece in response.iter_bytes(chunk_size):
            if not piece:
                continue
            received.extend(piece)
            progress.update(len(piece))

    parsed: dict[str, Any] = json.loads(received.decode("utf-8"))
    return parsed


def get_provider_data() -> dict[str, list[str]] | None:
    """Builds a provider-to-models mapping from the cached (or freshly fetched) data.

    The data is loaded through ``~/.crewai/provider_cache.json`` with a 24-hour
    expiry. Entries whose ``litellm_provider`` is empty, equals "other", or
    contains "http" are left out.

    Returns:
        A dictionary of providers mapped to their models or None if the operation fails.
    """
    crewai_home = Path.home() / ".crewai"
    crewai_home.mkdir(exist_ok=True)

    one_day_seconds = 24 * 3600
    raw = load_provider_data(crewai_home / "provider_cache.json", one_day_seconds)
    if not raw:
        return None

    grouped: defaultdict[str, list[str]] = defaultdict(list)
    for model_name, properties in raw.items():
        provider = properties.get("litellm_provider", "").strip().lower()
        if not provider or provider == "other" or "http" in provider:
            continue
        grouped[provider].append(model_name)
    return grouped
