"""Tests for event context management."""

import pytest

from crewai.events.event_context import (
    SCOPE_ENDING_EVENTS,
    SCOPE_STARTING_EVENTS,
    VALID_EVENT_PAIRS,
    EmptyStackError,
    EventPairingError,
    MismatchBehavior,
    StackDepthExceededError,
    _event_context_config,
    _event_id_stack,
    EventContextConfig,
    get_current_parent_id,
    get_enclosing_parent_id,
    get_last_event_id,
    get_triggering_event_id,
    handle_empty_pop,
    handle_mismatch,
    pop_event_scope,
    push_event_scope,
    reset_last_event_id,
    resume_task_scope,
    set_last_event_id,
    set_triggering_event_id,
    triggered_by_scope,
)


class TestStackOperations:
    """Tests for stack push/pop operations."""

    def setup_method(self) -> None:
        """Start every test from an empty scope stack."""
        _event_id_stack.set(())

    def teardown_method(self) -> None:
        """Drop scopes pushed by the test so they cannot leak into later tests."""
        _event_id_stack.set(())

    def test_empty_stack_returns_none(self) -> None:
        """With nothing pushed, neither parent lookup has an ID to report."""
        assert get_current_parent_id() is None
        assert get_enclosing_parent_id() is None

    def test_push_and_get_parent(self) -> None:
        """The most recently pushed event becomes the current parent."""
        push_event_scope("event-1", "task_started")
        assert get_current_parent_id() == "event-1"

    def test_nested_push(self) -> None:
        """With two scopes pushed, the inner is current and the outer is enclosing."""
        push_event_scope("event-1", "crew_kickoff_started")
        push_event_scope("event-2", "task_started")
        assert get_current_parent_id() == "event-2"
        assert get_enclosing_parent_id() == "event-1"

    def test_pop_restores_parent(self) -> None:
        """Popping returns the inner ``(event_id, event_type)`` and exposes the outer."""
        push_event_scope("event-1", "crew_kickoff_started")
        push_event_scope("event-2", "task_started")
        popped = pop_event_scope()
        assert popped == ("event-2", "task_started")
        assert get_current_parent_id() == "event-1"

    def test_pop_empty_stack_returns_none(self) -> None:
        """Popping an empty stack returns ``None``."""
        assert pop_event_scope() is None


class TestStackDepthLimit:
    """Tests for stack depth limit."""

    def test_depth_limit_exceeded_raises(self) -> None:
        """Pushing past ``max_stack_depth`` raises ``StackDepthExceededError``."""
        # Start from an empty stack so exactly three pushes fit under the limit,
        # regardless of what earlier tests left behind.
        _event_id_stack.set(())
        config_token = _event_context_config.set(
            EventContextConfig(max_stack_depth=3)
        )
        try:
            push_event_scope("event-1", "type-1")
            push_event_scope("event-2", "type-2")
            push_event_scope("event-3", "type-3")

            with pytest.raises(StackDepthExceededError):
                push_event_scope("event-4", "type-4")
        finally:
            # Restore the previous config and clear the stack so the reduced
            # depth limit and the pushed scopes do not leak into other tests.
            _event_context_config.reset(config_token)
            _event_id_stack.set(())


class TestMismatchHandling:
    """Tests for mismatch behavior."""

    def test_handle_mismatch_raises_when_configured(self) -> None:
        """``handle_mismatch`` raises when ``mismatch_behavior`` is ``RAISE``."""
        config_token = _event_context_config.set(
            EventContextConfig(mismatch_behavior=MismatchBehavior.RAISE)
        )
        try:
            with pytest.raises(EventPairingError):
                handle_mismatch("task_completed", "llm_call_started", "task_started")
        finally:
            # Put the previous config back so RAISE does not leak into other tests.
            _event_context_config.reset(config_token)

    def test_handle_empty_pop_raises_when_configured(self) -> None:
        """``handle_empty_pop`` raises when ``empty_pop_behavior`` is ``RAISE``."""
        config_token = _event_context_config.set(
            EventContextConfig(empty_pop_behavior=MismatchBehavior.RAISE)
        )
        try:
            with pytest.raises(EmptyStackError):
                handle_empty_pop("task_completed")
        finally:
            # Put the previous config back so RAISE does not leak into other tests.
            _event_context_config.reset(config_token)


class TestEventTypeSets:
    """Tests for event type set completeness."""

    def test_all_ending_events_have_pairs(self) -> None:
        """Every scope-ending event must be a key of ``VALID_EVENT_PAIRS``."""
        # Collect every offender so a failure lists all of them, not just the first.
        unpaired = {
            ending_event
            for ending_event in SCOPE_ENDING_EVENTS
            if ending_event not in VALID_EVENT_PAIRS
        }
        assert not unpaired, f"ending events missing from VALID_EVENT_PAIRS: {unpaired}"

    def test_all_pairs_reference_starting_events(self) -> None:
        """Every value in ``VALID_EVENT_PAIRS`` must be a scope-starting event."""
        unknown = {
            starting_event
            for starting_event in VALID_EVENT_PAIRS.values()
            if starting_event not in SCOPE_STARTING_EVENTS
        }
        assert not unknown, f"paired events missing from SCOPE_STARTING_EVENTS: {unknown}"

    def test_starting_and_ending_are_disjoint(self) -> None:
        """No event type may be both scope-starting and scope-ending."""
        overlap = SCOPE_STARTING_EVENTS & SCOPE_ENDING_EVENTS
        assert not overlap, f"events that both start and end a scope: {overlap}"


class TestLastEventIdTracking:
    """Tests for linear chain event ID tracking."""

    def test_initial_last_event_id_is_none(self) -> None:
        """After a reset there is no last event ID."""
        reset_last_event_id()
        observed = get_last_event_id()
        assert observed is None

    def test_set_and_get_last_event_id(self) -> None:
        """A stored ID is returned unchanged by the getter."""
        reset_last_event_id()
        set_last_event_id("event-123")
        observed = get_last_event_id()
        assert observed == "event-123"

    def test_reset_clears_last_event_id(self) -> None:
        """Resetting discards a previously stored ID."""
        set_last_event_id("event-123")
        reset_last_event_id()
        observed = get_last_event_id()
        assert observed is None

    def test_overwrite_last_event_id(self) -> None:
        """Only the most recently stored ID is kept."""
        reset_last_event_id()
        set_last_event_id("event-1")
        set_last_event_id("event-2")
        observed = get_last_event_id()
        assert observed == "event-2"


class TestTriggeringEventIdTracking:
    """Tests for causal chain event ID tracking."""

    def test_initial_triggering_event_id_is_none(self) -> None:
        """After being cleared, no triggering event ID is reported."""
        # Cleared explicitly so the test does not depend on what ran before it.
        set_triggering_event_id(None)
        assert get_triggering_event_id() is None

    def test_set_and_get_triggering_event_id(self) -> None:
        """A stored triggering ID is returned unchanged by the getter."""
        set_triggering_event_id("trigger-123")
        try:
            assert get_triggering_event_id() == "trigger-123"
        finally:
            # Clear even when the assertion fails so the ID cannot leak into
            # later tests.
            set_triggering_event_id(None)

    def test_set_none_clears_triggering_event_id(self) -> None:
        """Setting ``None`` discards a previously stored triggering ID."""
        set_triggering_event_id("trigger-123")
        set_triggering_event_id(None)
        assert get_triggering_event_id() is None


class TestTriggeredByScope:
    """Tests for triggered_by_scope context manager."""

    def test_scope_sets_triggering_id(self) -> None:
        """Inside the scope, the triggering ID is the one passed to it."""
        set_triggering_event_id(None)
        with triggered_by_scope("trigger-456"):
            assert get_triggering_event_id() == "trigger-456"

    def test_scope_restores_previous_value(self) -> None:
        """Leaving the scope puts back the value that was set before entering."""
        set_triggering_event_id(None)
        with triggered_by_scope("trigger-456"):
            pass
        assert get_triggering_event_id() is None

    def test_nested_scopes(self) -> None:
        """Each nested scope restores its own predecessor on exit."""
        set_triggering_event_id(None)
        with triggered_by_scope("outer"):
            assert get_triggering_event_id() == "outer"
            with triggered_by_scope("inner"):
                assert get_triggering_event_id() == "inner"
            assert get_triggering_event_id() == "outer"
        assert get_triggering_event_id() is None

    def test_scope_restores_on_exception(self) -> None:
        """An exception inside the scope propagates and the old value is restored."""
        set_triggering_event_id(None)
        # pytest.raises also verifies the error actually escapes the scope; a
        # bare try/except would pass even if the context manager swallowed it.
        with pytest.raises(ValueError, match="test error"), triggered_by_scope(
            "trigger-789"
        ):
            raise ValueError("test error")
        assert get_triggering_event_id() is None


class TestResumeTaskScope:
    """Tests for the checkpoint-resume scope helper."""

    @pytest.fixture(autouse=True)
    def _reset_stack(self) -> None:
        """Empty the scope stack before every test in this class."""
        _event_id_stack.set(())

    def _bind_runtime_state(self, *event_dicts: dict[str, object]):
        """Install a runtime state on the event bus seeded with task-started events.

        Each item in ``event_dicts`` must provide the keys ``"task_id"``,
        ``"event_id"`` and ``"emission_sequence"``; one ``TaskStartedEvent`` is
        recorded per item with those three attributes overwritten.

        Returns:
            A ``(bus, previous)`` tuple: the event bus and the runtime state it
            held before this call, so the caller can restore it afterwards.
        """
        # Imported locally, as in the other tests of this class.
        from crewai.events import crewai_event_bus
        from crewai.events.types.task_events import TaskStartedEvent
        from crewai.state.event_record import EventRecord
        from crewai.state.runtime import RuntimeState

        record = EventRecord()
        for spec in event_dicts:
            # Build a bare event, then overwrite the identifying fields with the
            # values from the spec.
            ev = TaskStartedEvent(context=None, task=None)
            ev.task_id = spec["task_id"]  # type: ignore[assignment]
            ev.event_id = spec["event_id"]  # type: ignore[assignment]
            ev.emission_sequence = spec["emission_sequence"]  # type: ignore[assignment]
            record.add(ev)
        state = RuntimeState(root=[])
        state._event_record = record

        # Swap the bus's runtime state, remembering the old one for restoration.
        previous = crewai_event_bus._runtime_state
        crewai_event_bus._runtime_state = state
        return crewai_event_bus, previous

    def test_returns_false_when_no_runtime_state(self) -> None:
        """Without a runtime state on the bus, nothing is pushed and False is returned."""
        from crewai.events import crewai_event_bus

        previous = crewai_event_bus._runtime_state
        crewai_event_bus._runtime_state = None
        try:
            assert resume_task_scope("any-task") is False
            assert _event_id_stack.get() == ()
        finally:
            # Always restore the bus state that was present before the test.
            crewai_event_bus._runtime_state = previous

    def test_returns_false_when_no_matching_event(self) -> None:
        """A task ID with no recorded event pushes nothing and returns False."""
        bus, previous = self._bind_runtime_state(
            {"task_id": "other", "event_id": "e1", "emission_sequence": 1},
        )
        try:
            assert resume_task_scope("missing") is False
            assert _event_id_stack.get() == ()
        finally:
            bus._runtime_state = previous

    def test_pushes_latest_event_for_task(self) -> None:
        """The task's event with the highest emission sequence is the one pushed."""
        # "e2" has the highest emission_sequence for "t1" although "e3" was
        # recorded after it; "x1" belongs to a different task.
        bus, previous = self._bind_runtime_state(
            {"task_id": "t1", "event_id": "e1", "emission_sequence": 1},
            {"task_id": "t1", "event_id": "e2", "emission_sequence": 5},
            {"task_id": "t1", "event_id": "e3", "emission_sequence": 3},
            {"task_id": "t2", "event_id": "x1", "emission_sequence": 9},
        )
        try:
            assert resume_task_scope("t1") is True
            assert _event_id_stack.get() == (("e2", "task_started"),)
        finally:
            bus._runtime_state = previous

    def test_pairs_cleanly_with_task_completed(self) -> None:
        """The pushed scope must be popped by a matching task_completed.

        After the completed event is emitted and flushed, only the enclosing
        kickoff scope remains and the event's ``started_event_id`` is the
        resumed task-started event.
        """
        from crewai.events import crewai_event_bus
        from crewai.events.types.task_events import TaskCompletedEvent
        from crewai.tasks.task_output import TaskOutput

        # Outer scope that must still be on the stack once the task scope is popped.
        push_event_scope("kickoff-1", "crew_kickoff_started")
        bus, previous = self._bind_runtime_state(
            {"task_id": "t1", "event_id": "started-1", "emission_sequence": 1},
        )
        try:
            assert resume_task_scope("t1") is True
            output = TaskOutput(description="d", raw="r", agent="a")
            completed = TaskCompletedEvent(output=output, task=None)
            completed.task_id = "t1"
            crewai_event_bus.emit(None, completed)
            # Flush before inspecting the stack and the emitted event.
            crewai_event_bus.flush()
            assert _event_id_stack.get() == (("kickoff-1", "crew_kickoff_started"),)
            assert completed.started_event_id == "started-1"
        finally:
            # Restore the bus and clear the stack so nothing leaks out of this test.
            bus._runtime_state = previous
            _event_id_stack.set(())


def test_agent_scope_preserved_after_tool_error_event() -> None:
    """A tool started/error event pair must leave the agent scope on top of the stack."""
    from crewai.events import crewai_event_bus
    from crewai.events.types.tool_usage_events import (
        ToolUsageErrorEvent,
        ToolUsageStartedEvent,
    )

    # Start from an empty stack so the final assertion reflects only the scopes
    # pushed here, not leftovers from earlier tests.
    _event_id_stack.set(())
    try:
        push_event_scope("crew-1", "crew_kickoff_started")
        push_event_scope("task-1", "task_started")
        push_event_scope("agent-1", "agent_execution_started")

        crewai_event_bus.emit(
            None,
            ToolUsageStartedEvent(
                tool_name="test_tool",
                tool_args={},
                agent_key="test_agent",
            ),
        )

        crewai_event_bus.emit(
            None,
            ToolUsageErrorEvent(
                tool_name="test_tool",
                tool_args={},
                agent_key="test_agent",
                error=ValueError("test error"),
            ),
        )

        # Flush before inspecting the stack.
        crewai_event_bus.flush()

        assert get_current_parent_id() == "agent-1"
    finally:
        # Clear the three scopes pushed above so they do not leak into other tests.
        _event_id_stack.set(())

