import json
import os
from typing import Any

from crewai.tools import BaseTool
from dotenv import load_dotenv
from pydantic import BaseModel, Field

from crewai_tools.aws.bedrock.exceptions import (
    BedrockKnowledgeBaseError,
    BedrockValidationError,
)


# Load environment variables from a .env file when this module is imported, so
# settings read later via os.getenv (BEDROCK_KB_ID, AWS_REGION,
# AWS_DEFAULT_REGION) can be supplied there instead of the process environment.
load_dotenv()


class BedrockKBRetrieverToolInput(BaseModel):
    """Input schema for BedrockKBRetrieverTool."""

    # Mandatory free-text query; the Ellipsis default marks the field as required.
    query: str = Field(
        default=...,
        description="The query to retrieve information from the knowledge base",
    )


class BedrockKBRetrieverTool(BaseTool):
    """CrewAI tool that queries an Amazon Bedrock Knowledge Base via ``Retrieve``.

    A text query is sent to the configured knowledge base and the matching
    results (content, source location, score, metadata) are returned to the
    agent as a pretty-printed JSON string.
    """

    name: str = "Bedrock Knowledge Base Retriever Tool"
    description: str = (
        "Retrieves information from an Amazon Bedrock Knowledge Base given a query"
    )
    args_schema: type[BaseModel] = BedrockKBRetrieverToolInput
    # Placeholder default; the real value is resolved in __init__ from the
    # argument or the BEDROCK_KB_ID environment variable, then validated.
    knowledge_base_id: str = None  # type: ignore[assignment]
    number_of_results: int | None = 5
    retrieval_configuration: dict[str, Any] | None = None
    guardrail_configuration: dict[str, Any] | None = None
    next_token: str | None = None
    package_dependencies: list[str] = Field(default_factory=lambda: ["boto3"])

    def __init__(
        self,
        knowledge_base_id: str | None = None,
        number_of_results: int | None = 5,
        retrieval_configuration: dict[str, Any] | None = None,
        guardrail_configuration: dict[str, Any] | None = None,
        next_token: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the BedrockKBRetrieverTool with knowledge base configuration.

        Args:
            knowledge_base_id: The unique identifier of the knowledge base to query.
                Falls back to the ``BEDROCK_KB_ID`` environment variable when
                omitted or empty.
            number_of_results: The maximum number of results to return. Only used
                to build the retrieval configuration when
                ``retrieval_configuration`` is not supplied.
            retrieval_configuration: Configurations for the knowledge base query and
                retrieval process. When given, it is sent unchanged and
                ``number_of_results`` is not merged into it.
            guardrail_configuration: Guardrail settings.
            next_token: Token for retrieving the next batch of results.
            **kwargs: Forwarded to ``BaseTool.__init__``.

        Raises:
            BedrockValidationError: If any parameter fails validation.
        """
        super().__init__(**kwargs)

        # An explicit argument takes precedence over the environment variable.
        self.knowledge_base_id = knowledge_base_id or os.getenv("BEDROCK_KB_ID")  # type: ignore[assignment]
        self.number_of_results = number_of_results
        self.guardrail_configuration = guardrail_configuration
        self.next_token = next_token

        # A caller-supplied retrieval configuration wins over the one derived
        # from number_of_results.
        if retrieval_configuration is None:
            self.retrieval_configuration = self._build_retrieval_configuration()
        else:
            self.retrieval_configuration = retrieval_configuration

        self._validate_parameters()

        # Include the concrete knowledge base id in the tool description.
        self.description = f"Retrieves information from Amazon Bedrock Knowledge Base '{self.knowledge_base_id}' given a query"

    def _build_retrieval_configuration(self) -> dict[str, Any]:
        """Build the retrieval configuration based on provided parameters.

        Returns:
            Dict[str, Any]: ``{"vectorSearchConfiguration": {...}}``, containing
            ``numberOfResults`` only when ``number_of_results`` is not None.
        """
        vector_search_config: dict[str, Any] = {}

        if self.number_of_results is not None:
            vector_search_config["numberOfResults"] = self.number_of_results

        return {"vectorSearchConfiguration": vector_search_config}

    def _validate_parameters(self) -> None:
        """Validate the parameters according to AWS API requirements.

        Raises:
            BedrockValidationError: If ``knowledge_base_id`` is missing, not a
                string, longer than 10 characters, or not ASCII alphanumeric; if
                ``next_token`` is not a string, exceeds 2048 characters, or
                contains whitespace; or if ``number_of_results`` is not a
                positive integer.
        """
        try:
            kb_id = self.knowledge_base_id
            if not kb_id:
                raise BedrockValidationError("knowledge_base_id cannot be empty")
            if not isinstance(kb_id, str):
                raise BedrockValidationError("knowledge_base_id must be a string")
            if len(kb_id) > 10:
                raise BedrockValidationError(
                    "knowledge_base_id must be 10 characters or less"
                )
            # str.isalnum() alone also accepts non-ASCII letters/digits (e.g. "é"),
            # which the API's [0-9a-zA-Z] pattern rejects.
            if not (kb_id.isascii() and kb_id.isalnum()):
                raise BedrockValidationError(
                    "knowledge_base_id must contain only alphanumeric characters"
                )

            # An empty next_token counts as "not provided" (it is also skipped
            # in _run), so only truthy values are checked.
            if self.next_token:
                if not isinstance(self.next_token, str):
                    raise BedrockValidationError("next_token must be a string")
                if len(self.next_token) > 2048:
                    raise BedrockValidationError(
                        "next_token must be between 1 and 2048 characters"
                    )
                # The API pattern is ^\S*$, so any whitespace is invalid, not
                # just spaces.
                if any(c.isspace() for c in self.next_token):
                    raise BedrockValidationError(
                        "next_token cannot contain whitespace"
                    )

            if self.number_of_results is not None:
                # bool is a subclass of int; reject it so True/False is never
                # sent as numberOfResults.
                if isinstance(self.number_of_results, bool) or not isinstance(
                    self.number_of_results, int
                ):
                    raise BedrockValidationError("number_of_results must be an integer")
                if self.number_of_results < 1:
                    raise BedrockValidationError(
                        "number_of_results must be greater than 0"
                    )

        except BedrockValidationError as e:
            raise BedrockValidationError(f"Parameter validation failed: {e!s}") from e

    def _process_retrieval_result(self, result: dict[str, Any]) -> dict[str, Any]:
        """Process a single retrieval result from Bedrock Knowledge Base.

        Args:
            result (Dict[str, Any]): Raw result from Bedrock Knowledge Base

        Returns:
            Dict[str, Any]: ``content``, ``content_type``, ``source_type`` and
            ``source_uri``, plus ``score``, ``metadata``, ``byte_content`` and
            ``row_content`` when present in the raw result.
        """
        # `or {}` also covers keys that are present but None.
        content_obj = result.get("content") or {}
        content = content_obj.get("text", "")
        content_type = content_obj.get("type", "text")

        location = result.get("location") or {}
        location_type = location.get("type", "unknown")
        source_uri = None

        # Location sub-object key -> field holding its URI, and the fallback
        # source type used when the response carries no "type".
        location_mapping = {
            "s3Location": {"field": "uri", "type": "S3"},
            "confluenceLocation": {"field": "url", "type": "Confluence"},
            "salesforceLocation": {"field": "url", "type": "Salesforce"},
            "sharePointLocation": {"field": "url", "type": "SharePoint"},
            "webLocation": {"field": "url", "type": "Web"},
            "customDocumentLocation": {"field": "id", "type": "CustomDocument"},
            "kendraDocumentLocation": {"field": "uri", "type": "KendraDocument"},
            "sqlLocation": {"field": "query", "type": "SQL"},
        }

        # Use the first matching location key only.
        for loc_key, config in location_mapping.items():
            if loc_key in location:
                source_uri = (location[loc_key] or {}).get(config["field"])
                if not location_type or location_type == "unknown":
                    location_type = config["type"]
                break

        result_object: dict[str, Any] = {
            "content": content,
            "content_type": content_type,
            "source_type": location_type,
            "source_uri": source_uri,
        }

        if "score" in result:
            result_object["score"] = result["score"]

        if "metadata" in result:
            result_object["metadata"] = result["metadata"]

        if "byteContent" in content_obj:
            result_object["byte_content"] = content_obj["byteContent"]

        if "row" in content_obj:
            result_object["row_content"] = content_obj["row"]

        return result_object

    def _run(self, query: str) -> str:
        """Query the knowledge base and return the results as a JSON string.

        Args:
            query: Text sent as ``retrievalQuery.text``.

        Returns:
            str: Indented JSON containing ``results`` (or ``message`` when nothing
            matched), plus ``nextToken`` and ``guardrailAction`` when the API
            response includes them.

        Raises:
            ImportError: If ``boto3`` is not installed.
            BedrockKnowledgeBaseError: If the API call fails or the response
                cannot be processed.
        """
        try:
            import boto3
            from botocore.exceptions import ClientError
        except ImportError as e:
            raise ImportError(
                "`boto3` package not found, please run `uv add boto3`"
            ) from e

        try:
            # Credentials are resolved by the default boto3 credential chain.
            bedrock_agent_runtime = boto3.client(
                "bedrock-agent-runtime",
                region_name=os.getenv(
                    "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1")
                ),
            )

            retrieve_params: dict[str, Any] = {
                "knowledgeBaseId": self.knowledge_base_id,
                "retrievalQuery": {"text": query},
            }

            # Optional parameters are only sent when set (falsy values are omitted).
            if self.retrieval_configuration:
                retrieve_params["retrievalConfiguration"] = self.retrieval_configuration

            if self.guardrail_configuration:
                retrieve_params["guardrailConfiguration"] = self.guardrail_configuration

            if self.next_token:
                retrieve_params["nextToken"] = self.next_token

            response = bedrock_agent_runtime.retrieve(**retrieve_params)

            results = [
                self._process_retrieval_result(result)
                for result in response.get("retrievalResults", [])
            ]

            response_object: dict[str, Any] = {}
            if results:
                response_object["results"] = results
            else:
                response_object["message"] = "No results found for the given query."

            if "nextToken" in response:
                response_object["nextToken"] = response["nextToken"]

            if "guardrailAction" in response:
                response_object["guardrailAction"] = response["guardrailAction"]

            return json.dumps(response_object, indent=2)

        except ClientError as e:
            error_code = "Unknown"
            error_message = str(e)

            # Prefer the structured AWS error code/message when available.
            if hasattr(e, "response") and "Error" in e.response:
                error_code = e.response["Error"].get("Code", "Unknown")
                error_message = e.response["Error"].get("Message", str(e))

            raise BedrockKnowledgeBaseError(
                f"Error ({error_code}): {error_message}"
            ) from e
        except Exception as e:
            raise BedrockKnowledgeBaseError(f"Unexpected error: {e!s}") from e
