from typing import Any, Dict, List, Optional, Union

import pytest
from crewai import Agent, Crew, Process, Task
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm


class CustomLLM(BaseLLM):
    """Custom LLM implementation for testing.

    A minimal concrete subclass of the BaseLLM abstract base class that
    returns a predefined response and counts how many times ``call`` was
    invoked.
    """

    def __init__(self, response="Default response", model="test-model"):
        """Initialize the CustomLLM with a predefined response.

        Args:
            response: The predefined response to return from call().
            model: Model identifier forwarded to ``BaseLLM.__init__``.
        """
        super().__init__(model=model)
        self.response = response
        self.call_count = 0

    def call(
        self,
        messages,
        tools=None,
        callbacks=None,
        available_functions=None,
        from_task=None,
        from_agent=None,
        response_model=None,
    ):
        """Mock LLM call that returns a predefined response.

        A bare string prompt is wrapped as a single user message. Each
        message whose ``content`` is a string is then rewritten into the
        OpenAI-style ``[{"type": "text", "text": ...}]`` structure. The
        rewrite is done on copies, so the caller's list and dicts are left
        untouched.

        Args:
            messages: A prompt string or a list of ``{"role", "content"}`` dicts.
            tools, callbacks, available_functions, from_task, from_agent,
            response_model: Accepted for interface compatibility; ignored.

        Returns:
            ``"Thought: I will say hi\\nFinal Answer: <response>"`` when the
            formatted messages contain ``"Thought:"``, otherwise the bare
            predefined response.
        """
        self.call_count += 1

        # If input is a string, convert to proper message format
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # Build formatted copies instead of mutating the caller's message dicts.
        formatted = [
            {**message, "content": [{"type": "text", "text": message["content"]}]}
            if isinstance(message["content"], str)
            else message
            for message in messages
        ]

        # Return predefined response in expected format
        if "Thought:" in str(formatted):
            return f"Thought: I will say hi\nFinal Answer: {self.response}"
        return self.response

    def supports_function_calling(self) -> bool:
        """Return False to indicate that function calling is not supported.

        Returns:
            False, indicating that this LLM does not support function calling.
        """
        return False

    def supports_stop_words(self) -> bool:
        """Return False to indicate that stop words are not supported.

        Returns:
            False, indicating that this LLM does not support stop words.
        """
        return False

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            4096, a typical context window size for modern LLMs.
        """
        return 4096

    async def acall(
        self,
        messages,
        tools=None,
        callbacks=None,
        available_functions=None,
        from_task=None,
        from_agent=None,
        response_model=None,
    ):
        """Async calls are not supported by this test double."""
        raise NotImplementedError


@pytest.mark.vcr()
def test_custom_llm_implementation():
    """Check that create_llm returns a BaseLLM instance as-is and that it can be called."""
    llm = CustomLLM(response="The answer is 42")

    # Passing an existing BaseLLM instance should yield that very object.
    resolved = create_llm(llm)
    assert resolved is llm

    # The predefined response should surface through call().
    question = "What is the answer to life, the universe, and everything?"
    answer = resolved.call(question)
    assert "42" in answer


@pytest.mark.vcr()
def test_custom_llm_within_crew():
    """Test that a custom LLM implementation can drive an Agent inside a Crew.

    Runs a single-agent, single-task sequential crew backed by CustomLLM.
    It then checks that the LLM was invoked and that its response reached
    the crew's raw output.
    """
    custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")

    agent = Agent(
        role="Say Hi",
        goal="Say hi to the user",
        backstory="""You just say hi to the user""",
        llm=custom_llm,
    )

    task = Task(
        description="Say hi to the user",
        expected_output="A greeting to the user",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
    )

    result = crew.kickoff()

    # Assert the LLM was called
    assert custom_llm.call_count > 0
    # Assert we got a response
    assert "Hello!" in result.raw


def test_custom_llm_message_formatting():
    """Check that CustomLLM.call accepts both a bare string and a message list."""
    llm = CustomLLM(response="Test response", model="test-model")

    # A plain string prompt.
    assert llm.call("Test message") == "Test response"

    # A list of role/content dicts.
    conversation = [
        {"role": "system", "content": "System message"},
        {"role": "user", "content": "User message"},
    ]
    assert llm.call(conversation) == "Test response"


class JWTAuthLLM(BaseLLM):
    """Custom LLM test double that carries a JWT for authentication.

    No external service is contacted. Every call is recorded in
    ``self.calls`` and a fixed reply string is returned.
    """

    def __init__(self, jwt_token: str):
        """Create the LLM and store a validated JWT token.

        Args:
            jwt_token: The token to hold; must be a non-empty ``str``.

        Raises:
            ValueError: If ``jwt_token`` is falsy or not a ``str``.
        """
        super().__init__(model="test-model")
        token_is_valid = bool(jwt_token) and isinstance(jwt_token, str)
        if not token_is_valid:
            raise ValueError("Invalid JWT token")
        self.jwt_token = jwt_token
        self.calls = []
        self.stop = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task=None,
        from_agent=None,
        response_model=None,
    ) -> Union[str, Any]:
        """Log the call's arguments and return a canned reply."""
        record = dict(
            messages=messages,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
        )
        self.calls.append(record)
        # A production implementation would authenticate against an external
        # service using self.jwt_token at this point.
        return "Response from JWT-authenticated LLM"

    def supports_function_calling(self) -> bool:
        """Report that function calling is supported (always True)."""
        return True

    def supports_stop_words(self) -> bool:
        """Report that stop words are supported (always True)."""
        return True

    def get_context_window_size(self) -> int:
        """Report a fixed context window size of 8192."""
        return 8192

    async def acall(
        self,
        messages,
        tools=None,
        callbacks=None,
        available_functions=None,
        from_task=None,
        from_agent=None,
        response_model=None,
    ):
        """Async calls are not supported by this test double."""
        raise NotImplementedError


def test_custom_llm_with_jwt_auth():
    """Check that create_llm returns the JWT LLM as-is and that calls are recorded."""
    llm = JWTAuthLLM(jwt_token="example.jwt.token")

    # create_llm should return the exact instance it was given.
    resolved = create_llm(llm)
    assert resolved is llm

    reply = resolved.call("Test message")

    # At least one call must have been logged, and the canned reply returned.
    assert len(llm.calls) > 0
    assert reply == "Response from JWT-authenticated LLM"


def test_jwt_auth_llm_validation():
    """Test that JWT token validation rejects empty and non-string tokens."""
    # Empty string: rejected by the falsiness check.
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token="")

    # None: also falsy, so it is rejected before the type check is reached.
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token=None)

    # Truthy non-string: the only case that exercises the isinstance(str) check.
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token=12345)


class TimeoutHandlingLLM(BaseLLM):
    """Custom LLM implementation with timeout handling and retry logic.

    Failures are simulated through ``fail_count``. Each attempt consumes one
    pending failure, and the first attempt made with no failures pending
    succeeds.
    """

    def __init__(self, max_retries: int = 3, timeout: int = 30):
        """Initialize the TimeoutHandlingLLM with retry and timeout settings.

        Args:
            max_retries: Maximum number of attempts (the first try included).
            timeout: Timeout in seconds for each API call (stored only; not
                used by the simulated call).
        """
        super().__init__(model="test-model")
        self.max_retries = max_retries
        self.timeout = timeout
        self.calls = []
        self.stop = []
        self.fail_count = 0  # Number of times to simulate failure

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task=None,
        from_agent=None,
        response_model=None,
    ) -> Union[str, Any]:
        """Simulate API calls with timeout handling and retry logic.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.
            from_task, from_agent, response_model: Accepted for interface
                compatibility; ignored.

        Returns:
            ``"First attempt response"`` if the first attempt succeeds,
            otherwise ``"Response after retry"`` from the first successful retry.

        Raises:
            TimeoutError: If every attempt fails. This includes the case where
                ``max_retries`` is less than 1 and no attempt can be made.
        """
        # Record the initial call
        self.calls.append(
            {
                "messages": messages,
                "tools": tools,
                "callbacks": callbacks,
                "available_functions": available_functions,
                "attempt": 0,
            }
        )

        for attempt in range(self.max_retries):
            if attempt > 0:
                # The first attempt was recorded above; record each retry here.
                self.calls.append(
                    {
                        "retry_attempt": attempt,
                        "messages": messages,
                        "tools": tools,
                        "callbacks": callbacks,
                        "available_functions": available_functions,
                    }
                )

            if self.fail_count > 0:
                # Simulated failure: consume it and move on (simulating backoff).
                self.fail_count -= 1
                continue

            return "First attempt response" if attempt == 0 else "Response after retry"

        # Every attempt failed, or max_retries < 1 left no attempt to make.
        # Raise here instead of falling through and implicitly returning None.
        raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported.

        Returns:
            True, indicating that this LLM supports function calling.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported.

        Returns:
            True, indicating that this LLM supports stop words.
        """
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            8192, a typical context window size for modern LLMs.
        """
        return 8192

    async def acall(
        self,
        messages,
        tools=None,
        callbacks=None,
        available_functions=None,
        from_task=None,
        from_agent=None,
        response_model=None,
    ):
        """Async calls are not supported by this test double."""
        raise NotImplementedError


def test_timeout_handling_llm():
    """Exercise TimeoutHandlingLLM's success, retry-then-success, and exhaustion paths."""
    # Scenario 1: no simulated failures, so the first attempt succeeds.
    happy = TimeoutHandlingLLM()
    assert happy.call("Test message") == "First attempt response"
    assert len(happy.calls) == 1

    # Scenario 2: one simulated failure, so the first retry succeeds.
    flaky = TimeoutHandlingLLM()
    flaky.fail_count = 1
    assert flaky.call("Test message") == "Response after retry"
    assert len(flaky.calls) == 2  # initial record + one successful retry record

    # Scenario 3: as many failures as attempts allowed, so TimeoutError.
    broken = TimeoutHandlingLLM(max_retries=2)
    broken.fail_count = 2
    with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
        broken.call("Test message")
    assert len(broken.calls) == 2  # initial record + one failed retry record
