"""Tests for unified memory: types, storage, Memory, MemoryScope, MemorySlice, Flow integration."""

from __future__ import annotations

from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock

import pytest

from crewai_core.printer import Printer
from crewai.memory.types import (
    MemoryConfig,
    MemoryMatch,
    MemoryRecord,
    ScopeInfo,
    compute_composite_score,
)


# --- Types ---


def test_memory_record_defaults() -> None:
    """A MemoryRecord built from content alone gets root scope and neutral defaults."""
    record = MemoryRecord(content="hello")

    assert record.content == "hello"
    assert record.scope == "/"
    assert record.categories == []
    assert record.importance == 0.5
    assert record.embedding is None
    # Identifier and creation timestamp are filled in automatically.
    assert record.id is not None
    assert isinstance(record.created_at, datetime)


def test_memory_match() -> None:
    """MemoryMatch exposes the wrapped record, its score and the match reasons."""
    record = MemoryRecord(content="x", scope="/a")
    match = MemoryMatch(record=record, score=0.9, match_reasons=["semantic"])

    assert match.record.content == "x"
    assert match.score == 0.9
    assert match.match_reasons == ["semantic"]


def test_memory_record_embedding_excluded_from_serialization() -> None:
    """Embedding vectors are hidden from dumps and repr but stay readable as an attribute."""
    record = MemoryRecord(content="hello", embedding=[0.1, 0.2, 0.3])

    # The storage layer still reads the vector straight off the attribute.
    assert record.embedding == [0.1, 0.2, 0.3]

    as_dict = record.model_dump()
    assert "embedding" not in as_dict
    assert as_dict["content"] == "hello"

    as_json = record.model_dump_json()
    assert "embedding" not in as_json
    # Round-tripping through JSON therefore drops the vector entirely.
    assert MemoryRecord.model_validate_json(as_json).embedding is None

    assert "embedding=" not in repr(record)

    # Serializing must not have disturbed the in-memory value.
    assert record.embedding is not None
    assert len(record.embedding) == 3


def test_memory_match_embedding_excluded_from_serialization() -> None:
    """Dumping a MemoryMatch must not leak the nested record's embedding vector."""
    match = MemoryMatch(
        record=MemoryRecord(content="x", embedding=[0.5] * 1536),
        score=0.9,
        match_reasons=["semantic"],
    )

    payload = match.model_dump()
    record_payload = payload["record"]
    assert "embedding" not in record_payload
    assert record_payload["content"] == "x"
    assert payload["score"] == 0.9


def test_scope_info() -> None:
    """ScopeInfo stores exactly the fields it is constructed with."""
    info = ScopeInfo(
        path="/",
        record_count=5,
        categories=["c1"],
        child_scopes=["/a"],
    )
    assert (info.path, info.record_count) == ("/", 5)
    assert info.categories == ["c1"]
    assert info.child_scopes == ["/a"]


def test_memory_config() -> None:
    """Default scoring weights are recency 0.3, semantic 0.5 and importance 0.2."""
    config = MemoryConfig()
    assert config.recency_weight == 0.3
    assert config.semantic_weight == 0.5
    assert config.importance_weight == 0.2


# --- LanceDB storage ---


@pytest.fixture
def lancedb_path(tmp_path: Path) -> Path:
    """Per-test LanceDB location under pytest's ``tmp_path`` (the directory is not created here)."""
    db_dir = tmp_path / "mem"
    return db_dir


def test_lancedb_save_search(lancedb_path: Path) -> None:
    """A saved record comes back from a vector search restricted to its scope."""
    from crewai.memory.storage.lancedb_storage import LanceDBStorage

    store = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
    store.save(
        [
            MemoryRecord(
                content="test content",
                scope="/foo",
                categories=["cat1"],
                importance=0.8,
                embedding=[0.1, 0.2, 0.3, 0.4],
            )
        ]
    )

    hits = store.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/foo", limit=5)
    assert len(hits) == 1

    found, score = hits[0]
    assert found.content == "test content"
    assert found.scope == "/foo"
    assert score >= 0.0


def test_lancedb_delete_count(lancedb_path: Path) -> None:
    """Deleting by the root scope prefix removes the saved record."""
    from crewai.memory.storage.lancedb_storage import LanceDBStorage

    store = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
    store.save([MemoryRecord(content="x", scope="/", embedding=[0.0] * 4)])
    assert store.count() == 1

    removed = store.delete(scope_prefix="/")
    assert removed >= 1
    assert store.count() == 0


def test_lancedb_list_scopes_get_scope_info(lancedb_path: Path) -> None:
    """list_scopes reports child scopes and get_scope_info summarizes the root."""
    from crewai.memory.storage.lancedb_storage import LanceDBStorage

    store = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
    records = [
        MemoryRecord(content=text, scope=scope, embedding=[0.0] * 4)
        for text, scope in (("a", "/"), ("b", "/team"))
    ]
    store.save(records)

    # Children of "/" are listed; the root itself is not.
    assert "/team" in store.list_scopes("/")

    root_info = store.get_scope_info("/")
    assert root_info.record_count >= 1
    assert root_info.path == "/"


# --- Memory class (with mock embedder, no LLM for explicit remember) ---


@pytest.fixture
def mock_embedder() -> MagicMock:
    """Batch-aware embedder stub that returns one 1536-dim vector of 0.1s per input text."""

    def _embed(texts):
        return [[0.1] * 1536 for _ in texts]

    return MagicMock(side_effect=_embed)


@pytest.fixture
def memory_with_storage(monkeypatch: pytest.MonkeyPatch) -> None:
    """Hide ``OPENAI_API_KEY`` from the environment for the duration of one test.

    Uses ``monkeypatch`` so the variable is restored once the test finishes.
    The previous ``os.environ.pop`` removed it for every later test in the session.
    """
    # raising=False: an absent variable is fine, matching the old pop(..., None).
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)


def test_memory_remember_recall_shallow(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """An explicitly classified memory can be recalled from its scope with a shallow search."""
    from crewai.memory.unified_memory import Memory

    memory = Memory(
        storage=str(tmp_path / "db"),
        llm=MagicMock(),
        embedder=mock_embedder,
    )
    # Supplying scope, categories and importance up front avoids LLM analysis.
    saved = memory.remember(
        "We decided to use Python.",
        scope="/project",
        categories=["decision"],
        importance=0.7,
    )
    assert saved.content == "We decided to use Python."
    assert saved.scope == "/project"

    hits = memory.recall("Python decision", scope="/project", limit=5, depth="shallow")
    assert len(hits) >= 1
    # Case-insensitive check, which also covers the capitalized spelling.
    assert "python" in hits[0].record.content.lower()


def test_memory_forget(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """forget(scope=...) deletes the records stored under that scope."""
    from crewai.memory.unified_memory import Memory

    memory = Memory(storage=str(tmp_path / "db2"), llm=MagicMock(), embedder=mock_embedder)
    memory.remember("To forget", scope="/x", categories=[], importance=0.5, metadata={})
    assert memory._storage.count("/x") >= 1

    deleted = memory.forget(scope="/x")
    assert deleted >= 1
    assert memory._storage.count("/x") == 0


def test_memory_scope_slice(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """Memory.scope() and Memory.slice() build views over the requested paths."""
    from crewai.memory.unified_memory import Memory

    memory = Memory(storage=str(tmp_path / "db3"), llm=MagicMock(), embedder=mock_embedder)

    scoped = memory.scope("/agent/1")
    # Either spelling is accepted; the root may carry a trailing slash.
    assert scoped._root in ("/agent/1", "/agent/1/")

    sliced = memory.slice(["/a", "/b"], read_only=True)
    assert sliced.read_only is True
    assert all(path in sliced.scopes for path in ("/a", "/b"))


def test_memory_list_scopes_info_tree(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """list_scopes, info and tree all reflect a root record plus a /team child."""
    from crewai.memory.unified_memory import Memory

    memory = Memory(storage=str(tmp_path / "db4"), llm=MagicMock(), embedder=mock_embedder)
    for content, scope in (("Root", "/"), ("Team note", "/team")):
        memory.remember(content, scope=scope, categories=[], importance=0.5, metadata={})

    # Only child scopes are listed; the root is not part of its own listing.
    assert "/team" in memory.list_scopes("/")
    assert memory.info("/").record_count >= 1

    rendered = memory.tree("/", max_depth=2)
    assert any(marker in rendered for marker in ("/", "0 records", "1 records"))


# --- MemoryScope ---


def test_memory_scope_remember_recall(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """A note remembered through a MemoryScope can be recalled through the same scope."""
    from crewai.memory.unified_memory import Memory
    from crewai.memory.memory_scope import MemoryScope

    backing = Memory(storage=str(tmp_path / "db5"), llm=MagicMock(), embedder=mock_embedder)
    crew_scope = MemoryScope(memory=backing, root_path="/crew/1")

    crew_scope.remember("Scoped note", scope="/", categories=[], importance=0.5, metadata={})

    assert len(crew_scope.recall("note", limit=5, depth="shallow")) >= 1


# --- MemorySlice recall (read-only) ---


def test_memory_slice_recall(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """Recall through a read-only MemorySlice returns a list."""
    from crewai.memory.unified_memory import Memory
    from crewai.memory.memory_scope import MemorySlice

    backing = Memory(storage=str(tmp_path / "db6"), llm=MagicMock(), embedder=mock_embedder)
    backing.remember("In scope A", scope="/a", categories=[], importance=0.5, metadata={})

    view = MemorySlice(memory=backing, scopes=["/a"], read_only=True)
    assert isinstance(view.recall("scope", limit=5, depth="shallow"), list)


def test_memory_slice_remember_is_noop_when_read_only(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """remember() on a read-only slice returns None and writes nothing."""
    from crewai.memory.unified_memory import Memory
    from crewai.memory.memory_scope import MemorySlice

    backing = Memory(storage=str(tmp_path / "db7"), llm=MagicMock(), embedder=mock_embedder)
    view = MemorySlice(memory=backing, scopes=["/a"], read_only=True)

    assert view.remember("x", scope="/a") is None
    assert backing.list_records() == []


# --- Flow memory ---


def test_flow_has_default_memory() -> None:
    """A Flow constructed without a memory gets a Memory instance automatically."""
    from crewai.flow.flow import Flow
    from crewai.memory.unified_memory import Memory

    class DefaultFlow(Flow):
        pass

    flow = DefaultFlow()
    default_memory = flow.memory
    assert default_memory is not None
    assert isinstance(default_memory, Memory)


def test_flow_recall_remember_raise_when_memory_explicitly_none() -> None:
    """recall() and remember() raise ValueError once the flow's memory is set to None."""
    from crewai.flow.flow import Flow

    class NoMemoryFlow(Flow):
        memory = None

    flow = NoMemoryFlow()
    # __init__ auto-creates a memory, so clear it again on the instance.
    flow.memory = None

    operations = (
        lambda: flow.recall("query"),
        lambda: flow.remember("content"),
    )
    for operation in operations:
        with pytest.raises(ValueError, match="No memory configured"):
            operation()


def test_flow_recall_remember_with_memory(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """A flow with a class-level Memory can remember a fact and then recall it."""
    from crewai.flow.flow import Flow
    from crewai.memory.unified_memory import Memory

    shared_memory = Memory(
        storage=str(tmp_path / "flow_db"),
        llm=MagicMock(),
        embedder=mock_embedder,
    )

    class FlowWithMemory(Flow):
        memory = shared_memory

    flow = FlowWithMemory()
    flow.remember(
        "Flow remembered this",
        scope="/flow",
        categories=[],
        importance=0.6,
        metadata={},
    )
    assert len(flow.recall("remembered", limit=5, depth="shallow")) >= 1


# --- extract_memories ---


def test_memory_extract_memories_returns_list_from_llm(tmp_path: Path) -> None:
    """extract_memories() asks the LLM for ExtractedMemories and returns its strings."""
    from crewai.memory.analyze import ExtractedMemories
    from crewai.memory.unified_memory import Memory

    llm = MagicMock()
    llm.supports_function_calling.return_value = True
    llm.call.return_value = ExtractedMemories(
        memories=["We use Python for the backend.", "API rate limit is 100/min."]
    )
    memory = Memory(
        storage=str(tmp_path / "extract_db"),
        llm=llm,
        embedder=MagicMock(return_value=[[0.1] * 1536]),
    )

    extracted = memory.extract_memories(
        "Task: Build API. Result: We used Python and set rate limit 100/min."
    )

    assert extracted == ["We use Python for the backend.", "API rate limit is 100/min."]
    llm.call.assert_called_once()
    # The structured-output model is passed as a keyword argument.
    assert llm.call.call_args.kwargs.get("response_model") == ExtractedMemories


def test_memory_extract_memories_empty_content_returns_empty_list(tmp_path: Path) -> None:
    """Blank input short-circuits extract_memories() to [] without touching the LLM."""
    from crewai.memory.unified_memory import Memory

    llm = MagicMock()
    memory = Memory(storage=str(tmp_path / "empty_db"), llm=llm, embedder=MagicMock())

    for blank in ("", "   \n  "):
        assert memory.extract_memories(blank) == []
    llm.call.assert_not_called()


def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
    """_save_to_memory extracts facts from the formatted result, then saves them in one batch.

    The executor builds a "Task / Agent / Expected result / Result" transcript and
    passes it to ``memory.extract_memories``. It then hands the whole list of
    extracted facts to a single ``memory.remember_many`` call. Despite the test's
    historical name, it does not make one ``remember`` call per item.
    """
    from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
    from crewai.agents.parser import AgentFinish

    mock_memory = MagicMock()
    mock_memory.read_only = False
    mock_memory.extract_memories.return_value = ["Fact A.", "Fact B."]

    mock_agent = MagicMock()
    mock_agent.memory = mock_memory
    mock_agent._logger = MagicMock()
    mock_agent.role = "Researcher"

    mock_task = MagicMock()
    mock_task.description = "Do research"
    mock_task.expected_output = "A report"

    executor = BaseAgentExecutor()
    executor.agent = mock_agent
    executor.task = mock_task
    executor._save_to_memory(
        AgentFinish(thought="", output="We found X and Y.", text="We found X and Y.")
    )

    raw_expected = "Task: Do research\nAgent: Researcher\nExpected result: A report\nResult: We found X and Y."
    mock_memory.extract_memories.assert_called_once_with(raw_expected)
    # All extracted facts go through one batched call.
    mock_memory.remember_many.assert_called_once()
    saved_contents = mock_memory.remember_many.call_args.args[0]
    assert saved_contents == ["Fact A.", "Fact B."]


def test_executor_save_to_memory_skips_delegation_output() -> None:
    """_save_to_memory does nothing when output contains delegate action.

    Extraction must not run, and neither save path (single ``remember`` or
    batched ``remember_many``) may be invoked for a delegation step.
    """
    from crewai.agents.agent_builder.base_agent_executor import BaseAgentExecutor
    from crewai.agents.parser import AgentFinish
    from crewai.utilities.string_utils import sanitize_tool_name

    mock_memory = MagicMock()
    mock_memory.read_only = False
    mock_agent = MagicMock()
    mock_agent.memory = mock_memory
    mock_agent._logger = MagicMock()
    mock_task = MagicMock()
    mock_task.description = "Task"
    mock_task.expected_output = "Out"

    # An "Action:" line naming the sanitized delegation tool.
    delegate_text = f"Action: {sanitize_tool_name('Delegate work to coworker')}"
    full_text = delegate_text + " rest"
    executor = BaseAgentExecutor()
    executor.agent = mock_agent
    executor.task = mock_task
    executor._save_to_memory(
        AgentFinish(thought="", output=full_text, text=full_text)
    )

    mock_memory.extract_memories.assert_not_called()
    mock_memory.remember.assert_not_called()
    # The real save path is batched, so it must be checked too.
    mock_memory.remember_many.assert_not_called()


def test_memory_scope_extract_memories_delegates() -> None:
    """MemoryScope.extract_memories forwards the content unchanged to its Memory."""
    from crewai.memory.memory_scope import MemoryScope

    backing = MagicMock()
    backing.extract_memories.return_value = ["Scoped fact."]
    scoped = MemoryScope(memory=backing, root_path="/agent/1")

    extracted = scoped.extract_memories("Some content")

    backing.extract_memories.assert_called_once_with("Some content")
    assert extracted == ["Scoped fact."]


def test_memory_slice_extract_memories_delegates() -> None:
    """MemorySlice.extract_memories forwards the content unchanged to its Memory."""
    from crewai.memory.memory_scope import MemorySlice

    backing = MagicMock()
    backing.extract_memories.return_value = ["Sliced fact."]
    view = MemorySlice(memory=backing, scopes=["/a", "/b"], read_only=True)

    extracted = view.extract_memories("Some content")

    backing.extract_memories.assert_called_once_with("Some content")
    assert extracted == ["Sliced fact."]


def test_flow_extract_memories_raises_when_memory_explicitly_none() -> None:
    """Flow.extract_memories raises ValueError after the flow's memory is cleared."""
    from crewai.flow.flow import Flow

    flow = Flow()
    flow.memory = None  # replace the automatically created memory

    with pytest.raises(ValueError, match="No memory configured"):
        flow.extract_memories("some content")


def test_flow_extract_memories_delegates_when_memory_present() -> None:
    """Flow.extract_memories forwards to the flow's memory and returns its result."""
    from crewai.flow.flow import Flow

    backing = MagicMock()
    backing.extract_memories.return_value = ["Flow fact 1.", "Flow fact 2."]

    class FlowWithMemory(Flow):
        memory = backing

    extracted = FlowWithMemory().extract_memories("content here")

    backing.extract_memories.assert_called_once_with("content here")
    assert extracted == ["Flow fact 1.", "Flow fact 2."]


# --- Composite scoring ---


def test_composite_score_brand_new_memory() -> None:
    """A just-created record keeps full recency: 0.5*0.8 + 0.3*1.0 + 0.2*0.7 = 0.84."""
    config = MemoryConfig()
    fresh = MemoryRecord(
        content="test",
        scope="/",
        importance=0.7,
        created_at=datetime.utcnow(),
    )

    score, reasons = compute_composite_score(fresh, 0.8, config)

    assert 0.82 <= score <= 0.86
    for expected_reason in ("semantic", "recency", "importance"):
        assert expected_reason in reasons


def test_composite_score_old_memory_decayed() -> None:
    """60 days with a 30-day half-life gives decay 0.25: 0.5*0.8 + 0.3*0.25 + 0.2*0.5 = 0.575."""
    config = MemoryConfig(recency_half_life_days=30)
    stale = MemoryRecord(
        content="old",
        scope="/",
        importance=0.5,
        created_at=datetime.utcnow() - timedelta(days=60),
    )

    score, reasons = compute_composite_score(stale, 0.8, config)

    assert 0.55 <= score <= 0.60
    assert "semantic" in reasons
    # A decay of 0.25 is too low to count as a recency match.
    assert "recency" not in reasons


def test_composite_score_reranks_results(
    tmp_path: Path, mock_embedder: MagicMock
) -> None:
    """Same semantic score: high-importance recent memory ranks first."""
    from crewai.memory.unified_memory import Memory

    # Use same dim as default LanceDB (1536) so storage does not overwrite embedding
    emb = [0.1] * 1536
    mem = Memory(
        storage=str(tmp_path / "rerank_db"),
        llm=MagicMock(),
        embedder=MagicMock(return_value=[emb]),
    )
    # Save both records directly to storage (bypass encoding flow)
    # to test composite scoring in isolation without consolidation merging them.
    record_high = MemoryRecord(
        content="Important decision",
        scope="/",
        categories=[],
        importance=1.0,
        embedding=emb,
    )
    mem._storage.save([record_high])
    old = datetime.utcnow() - timedelta(days=90)
    record_low = MemoryRecord(
        content="Old trivial note",
        scope="/",
        importance=0.1,
        created_at=old,
        embedding=emb,
    )
    mem._storage.save([record_low])

    matches = mem.recall("decision", scope="/", limit=5, depth="shallow")
    assert len(matches) >= 2
    # Both records share one embedding, so ordering comes only from recency and
    # importance. The fresh, importance-1.0 record (saved directly to storage
    # above) must outrank the 90-day-old, importance-0.1 one.
    assert "important" in matches[0].record.content.lower()


def test_composite_score_match_reasons_populated() -> None:
    """Fresh, important records earn recency/importance reasons; old, minor ones do not."""
    config = MemoryConfig()

    fresh_high = MemoryRecord(content="x", importance=0.9, created_at=datetime.utcnow())
    _, fresh_reasons = compute_composite_score(fresh_high, 0.5, config)
    assert {"semantic", "recency", "importance"} <= set(fresh_reasons)

    old_low = MemoryRecord(
        content="y",
        importance=0.1,
        created_at=datetime.utcnow() - timedelta(days=60),
    )
    _, old_reasons = compute_composite_score(old_low, 0.5, config)
    assert "semantic" in old_reasons
    assert "recency" not in old_reasons
    assert "importance" not in old_reasons


def test_composite_score_custom_config() -> None:
    """With recency and importance weights zeroed, the composite equals the semantic score."""
    semantic_only = MemoryConfig(
        recency_weight=0.0,
        semantic_weight=1.0,
        importance_weight=0.0,
    )
    record = MemoryRecord(content="any", importance=0.9, created_at=datetime.utcnow())

    score, reasons = compute_composite_score(record, 0.73, semantic_only)

    assert score == pytest.approx(0.73, rel=1e-5)
    assert "semantic" in reasons


# --- LLM fallback ---


def test_analyze_for_save_llm_failure_returns_defaults() -> None:
    """analyze_for_save falls back to root scope and neutral defaults when the LLM raises."""
    from crewai.memory.analyze import MemoryAnalysis, analyze_for_save

    failing_llm = MagicMock()
    failing_llm.supports_function_calling.return_value = False
    failing_llm.call.side_effect = RuntimeError("API rate limit")

    analysis = analyze_for_save(
        "some content",
        existing_scopes=["/", "/project"],
        existing_categories=["cat1"],
        llm=failing_llm,
    )

    assert isinstance(analysis, MemoryAnalysis)
    assert analysis.suggested_scope == "/"
    assert analysis.categories == []
    assert analysis.importance == 0.5
    meta = analysis.extracted_metadata
    assert meta.entities == []
    assert meta.dates == []
    assert meta.topics == []


def test_extract_memories_llm_failure_returns_raw() -> None:
    """If the LLM raises, extract_memories_from_content returns the raw content as one item."""
    from crewai.memory.analyze import extract_memories_from_content

    failing_llm = MagicMock()
    failing_llm.call.side_effect = RuntimeError("Network error")

    raw = "Task result: We chose PostgreSQL."
    assert extract_memories_from_content(raw, failing_llm) == [raw]


def test_analyze_query_llm_failure_returns_defaults() -> None:
    """If the LLM raises, analyze_query returns simple defaults using the available scopes."""
    from crewai.memory.analyze import QueryAnalysis, analyze_query

    failing_llm = MagicMock()
    failing_llm.call.side_effect = RuntimeError("Timeout")

    analysis = analyze_query(
        "what did we decide?",
        available_scopes=["/", "/project", "/team", "/company", "/other", "/extra"],
        scope_info=None,
        llm=failing_llm,
    )

    assert isinstance(analysis, QueryAnalysis)
    assert analysis.keywords == []
    assert analysis.complexity == "simple"
    # Only the first five available scopes are suggested ("/extra" is dropped).
    assert analysis.suggested_scopes == ["/", "/project", "/team", "/company", "/other"]


def test_remember_survives_llm_failure(
    tmp_path: Path, mock_embedder: MagicMock
) -> None:
    """remember() still persists the record with default fields when the LLM call fails."""
    from crewai.memory.unified_memory import Memory

    failing_llm = MagicMock()
    failing_llm.call.side_effect = RuntimeError("LLM unavailable")
    memory = Memory(
        storage=str(tmp_path / "fallback_db"),
        llm=failing_llm,
        embedder=mock_embedder,
    )

    saved = memory.remember("We decided to use PostgreSQL.")

    assert saved.content == "We decided to use PostgreSQL."
    assert (saved.scope, saved.categories, saved.importance) == ("/", [], 0.5)
    assert saved.id is not None
    assert memory._storage.count() == 1


# --- Agent.kickoff() memory integration ---


def test_agent_kickoff_memory_recall_and_save(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """Agent.kickoff() with memory recalls before running and batch-saves afterwards."""
    from unittest.mock import Mock, patch

    from crewai.agent.core import Agent
    from crewai.llm import LLM
    from crewai.memory.unified_memory import Memory
    from crewai.types.usage_metrics import UsageMetrics

    memory = Memory(
        storage=str(tmp_path / "agent_kickoff_db"),
        llm=MagicMock(),
        embedder=mock_embedder,
    )
    # Seed one record so recall has something to find.
    memory.remember("The team uses PostgreSQL.", scope="/", categories=["database"], importance=0.8)

    agent_llm = Mock(spec=LLM)
    agent_llm.call.return_value = "Final Answer: PostgreSQL is the database."
    agent_llm.stop = []
    agent_llm.supports_stop_words.return_value = False
    agent_llm.supports_function_calling.return_value = False
    agent_llm.get_token_usage_summary.return_value = UsageMetrics(
        total_tokens=10,
        prompt_tokens=5,
        completion_tokens=5,
        cached_prompt_tokens=0,
        successful_requests=1,
    )

    agent = Agent(
        role="Tester",
        goal="Test memory integration",
        backstory="You test things.",
        llm=agent_llm,
        memory=memory,
        verbose=False,
    )

    # Patch on the class, not the instance: Pydantic BaseModel restricts __delattr__.
    with patch.object(Memory, "recall", wraps=memory.recall) as recall_spy:
        with patch.object(
            Memory, "extract_memories", return_value=["PostgreSQL is used."]
        ) as extract_stub:
            with patch.object(
                Memory, "remember_many", wraps=memory.remember_many
            ) as remember_many_spy:
                outcome = agent.kickoff("What database do we use?")

    assert outcome is not None
    assert outcome.raw is not None

    # Passive injection: memory was consulted before execution.
    recall_spy.assert_called_once()

    # Passive save: the run transcript was sent for extraction...
    extract_stub.assert_called_once()
    transcript = extract_stub.call_args.args[0]
    for marker in ("Input:", "Agent:", "Result:"):
        assert marker in transcript

    # ...and the extracted facts were stored in one batch.
    remember_many_spy.assert_called_once()
    assert "PostgreSQL is used." in remember_many_spy.call_args.args[0]


# --- Batch EncodingFlow tests ---


def test_batch_embed_single_call(tmp_path: Path) -> None:
    """remember_many() embeds a whole batch with exactly one embedder call."""
    from crewai.memory.unified_memory import Memory

    embedder = MagicMock(side_effect=lambda texts: [[0.1] * 1536 for _ in texts])
    llm = MagicMock()
    llm.supports_function_calling.return_value = False
    memory = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=embedder)

    memory.remember_many(
        ["Fact A.", "Fact B.", "Fact C."],
        scope="/test",
        categories=["test"],
        importance=0.5,
    )
    memory.drain_writes()  # the save itself runs in the background

    embedder.assert_called_once()
    batch = embedder.call_args.args[0]
    assert len(batch) == 3
    assert batch == ["Fact A.", "Fact B.", "Fact C."]


def test_intra_batch_dedup_drops_near_identical(tmp_path: Path) -> None:
    """Three identical strings in one remember_many() batch collapse to a single record."""
    from crewai.memory.unified_memory import Memory

    # Every text maps to the same vector, so pairwise cosine similarity is 1.0.
    embedder = MagicMock(side_effect=lambda texts: [[0.5] * 1536 for _ in texts])
    llm = MagicMock()
    llm.supports_function_calling.return_value = False
    memory = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=embedder)

    memory.remember_many(
        ["CrewAI ensures reliable operation."] * 3,
        scope="/test",
        categories=["reliability"],
        importance=0.7,
    )
    memory.drain_writes()  # the save itself runs in the background

    assert memory._storage.count() == 1


def test_intra_batch_dedup_keeps_merely_similar(tmp_path: Path) -> None:
    """remember_many with distinct items should keep all of them."""
    from crewai.memory.unified_memory import Memory

    # One-hot vectors at successive positions give every pair of texts a cosine
    # similarity of 0, so intra-batch dedup must not drop anything.
    call_count = 0

    def varying_embedder(texts: list[str]) -> list[list[float]]:
        nonlocal call_count
        result = []
        for offset in range(len(texts)):
            emb = [0.0] * 1536
            emb[(call_count + offset) % 1536] = 1.0
            result.append(emb)
        call_count += len(texts)
        return result

    embedder = MagicMock(side_effect=varying_embedder)
    llm = MagicMock()
    llm.supports_function_calling.return_value = False
    mem = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=embedder)

    mem.remember_many(
        ["CrewAI handles complex tasks.", "Python is the best language."],
        scope="/test",
        categories=["tech"],
        importance=0.6,
    )
    mem.drain_writes()  # wait for background save
    assert mem._storage.count() == 2


def test_batch_consolidation_deduplicates_against_storage(
    tmp_path: Path,
) -> None:
    """Pre-insert a record, then remember_many with same + new content.

    All texts share one embedding. Intra-batch dedup therefore leaves a single
    item, consolidation matches it to the pre-inserted record, and the stubbed
    LLM plan says not to insert, so the store still holds exactly one record.
    """
    from crewai.memory.unified_memory import Memory
    from crewai.memory.analyze import ConsolidationPlan

    emb = [0.1] * 1536
    embedder = MagicMock()
    embedder.side_effect = lambda texts: [emb for _ in texts]

    llm = MagicMock()
    llm.supports_function_calling.return_value = True
    # After intra-batch dedup (identical embeddings), only 1 item survives.
    # That item hits parallel_analyze which calls analyze_for_consolidation.
    # The single-item call returns a ConsolidationPlan directly.
    llm.call.return_value = ConsolidationPlan(
        actions=[], insert_new=False, insert_reason="duplicate"
    )

    mem = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=embedder)

    # Pre-insert (MemoryRecord comes from the module-level import).
    mem._storage.save([
        MemoryRecord(content="CrewAI is great.", scope="/test", importance=0.7, embedding=emb),
    ])
    assert mem._storage.count() == 1

    # remember_many with the same content + a new one (all identical embeddings)
    mem.remember_many(
        ["CrewAI is great.", "CrewAI is wonderful."],
        scope="/test",
        categories=["review"],
        importance=0.7,
    )
    mem.drain_writes()  # wait for background save
    # Intra-batch dedup fires: same embedding = 1.0 >= 0.98, so item 1 is dropped.
    # The remaining item finds the pre-existing record (similarity 1.0 >= 0.85).
    # LLM says don't insert -> no new records. Total stays at 1.
    assert mem._storage.count() == 1


def test_parallel_find_similar_runs_all_searches(tmp_path: Path) -> None:
    """Each of three distinct remember_many() items triggers its own storage search."""
    from unittest.mock import patch
    from crewai.memory.unified_memory import Memory

    issued = 0

    def distinct_embedder(texts: list[str]) -> list[list[float]]:
        """One-hot vector per text at a fresh index, so dedup never merges them."""
        nonlocal issued
        vectors = []
        for offset in range(len(texts)):
            vec = [0.0] * 1536
            vec[(issued + offset) % 1536] = 1.0
            vectors.append(vec)
        issued += len(texts)
        return vectors

    embedder = MagicMock(side_effect=distinct_embedder)
    llm = MagicMock()
    llm.supports_function_calling.return_value = False
    memory = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=embedder)

    store = memory._storage
    with patch.object(store, "search", wraps=store.search) as search_spy:
        memory.remember_many(
            ["Alpha fact.", "Beta fact.", "Gamma fact."],
            scope="/test",
            categories=["test"],
            importance=0.5,
        )
        memory.drain_writes()  # the save itself runs in the background
        assert search_spy.call_count == 3


def test_single_remember_uses_batch_flow(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """A single remember() goes through the batch path as a batch of one."""
    from crewai.memory.unified_memory import Memory

    llm = MagicMock()
    llm.supports_function_calling.return_value = False
    memory = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=mock_embedder)

    saved = memory.remember(
        "Single fact.",
        scope="/project",
        categories=["decision"],
        importance=0.8,
    )

    assert saved is not None
    assert (saved.content, saved.scope, saved.importance) == ("Single fact.", "/project", 0.8)
    assert memory._storage.count() == 1


def test_parallel_analyze_runs_concurrent_calls(tmp_path: Path) -> None:
    """remember_many with 3 items needing LLM should make 3 concurrent LLM calls.

    Only the number of LLM calls (one analyze_for_save per item) and the number of
    stored records are asserted. Actual concurrency is not observed here.
    """
    from crewai.memory.unified_memory import Memory
    from crewai.memory.analyze import MemoryAnalysis, ExtractedMetadata

    call_count = 0

    def distinct_embedder(texts: list[str]) -> list[list[float]]:
        """Return unique embeddings per text so dedup doesn't drop them."""
        nonlocal call_count
        result = []
        for offset in range(len(texts)):
            emb = [0.0] * 1536
            emb[(call_count + offset) % 1536] = 1.0
            result.append(emb)
        call_count += len(texts)
        return result

    embedder = MagicMock(side_effect=distinct_embedder)
    llm = MagicMock()
    llm.supports_function_calling.return_value = True
    # Return a valid MemoryAnalysis for field resolution calls
    llm.call.return_value = MemoryAnalysis(
        suggested_scope="/inferred",
        categories=["auto"],
        importance=0.6,
        extracted_metadata=ExtractedMetadata(),
    )

    mem = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=embedder)

    # No scope/categories/importance -> all 3 need field resolution (Group C)
    mem.remember_many(["Fact A.", "Fact B.", "Fact C."])
    mem.drain_writes()  # wait for background save
    # Each item triggers one analyze_for_save call -> 3 parallel LLM calls
    assert llm.call.call_count == 3
    assert mem._storage.count() == 3


# --- Non-blocking save tests ---


def test_remember_many_returns_immediately(tmp_path: Path) -> None:
    """remember_many() returns [] at once; the records appear after drain_writes()."""
    from crewai.memory.unified_memory import Memory

    issued = 0

    def distinct_embedder(texts: list[str]) -> list[list[float]]:
        nonlocal issued
        vectors = []
        for offset in range(len(texts)):
            vec = [0.0] * 1536
            vec[(issued + offset) % 1536] = 1.0
            vectors.append(vec)
        issued += len(texts)
        return vectors

    embedder = MagicMock(side_effect=distinct_embedder)
    llm = MagicMock()
    llm.supports_function_calling.return_value = False
    memory = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=embedder)

    immediate = memory.remember_many(
        ["Fact A.", "Fact B."],
        scope="/test",
        categories=["test"],
        importance=0.5,
    )
    # The call does not wait for the background save.
    assert immediate == []

    memory.drain_writes()
    assert memory._storage.count() == 2


def test_recall_drains_pending_writes(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """recall() waits for pending background saves before it searches."""
    from crewai.memory.unified_memory import Memory

    llm = MagicMock()
    llm.supports_function_calling.return_value = False
    memory = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=mock_embedder)

    # Queue a background save and deliberately do not drain it here.
    memory.remember_many(
        ["Python is great."],
        scope="/test",
        categories=["lang"],
        importance=0.7,
    )

    hits = memory.recall("Python", scope="/test", limit=5, depth="shallow")
    assert len(hits) >= 1
    assert "Python" in hits[0].record.content


def test_close_drains_and_shuts_down(tmp_path: Path, mock_embedder: MagicMock) -> None:
    """close() flushes pending background saves before shutting down."""
    from crewai.memory.unified_memory import Memory

    llm = MagicMock()
    llm.supports_function_calling.return_value = False
    memory = Memory(storage=str(tmp_path / "db"), llm=llm, embedder=mock_embedder)

    memory.remember_many(
        ["Important fact."],
        scope="/test",
        categories=["test"],
        importance=0.9,
    )
    memory.close()

    # The queued record must have been persisted by close().
    assert memory._storage.count() == 1
