diff --git a/aieng-eval-agents/aieng/agent_evals/aml_investigation/__init__.py b/aieng-eval-agents/aieng/agent_evals/aml_investigation/__init__.py index 0d8b739..6d0ed6d 100644 --- a/aieng-eval-agents/aieng/agent_evals/aml_investigation/__init__.py +++ b/aieng-eval-agents/aieng/agent_evals/aml_investigation/__init__.py @@ -1,5 +1,6 @@ """Utilities for AML Investigation agent.""" +from .agent import create_aml_investigation_agent from .data.cases import ( AnalystOutput, CaseFile, @@ -9,14 +10,17 @@ build_cases, parse_patterns_file, ) +from .task import AmlInvestigationTask __all__ = [ + "AmlInvestigationTask", "AnalystOutput", "CaseFile", "CaseRecord", "LaunderingPattern", "GroundTruth", "build_cases", + "create_aml_investigation_agent", "parse_patterns_file", ] diff --git a/aieng-eval-agents/aieng/agent_evals/aml_investigation/agent.py b/aieng-eval-agents/aieng/agent_evals/aml_investigation/agent.py new file mode 100644 index 0000000..ac6b860 --- /dev/null +++ b/aieng-eval-agents/aieng/agent_evals/aml_investigation/agent.py @@ -0,0 +1,215 @@ +"""AML investigation agent. + +This module defines the primary factory used to build the AML investigation agent. + +The returned agent is a Google ADK ``LlmAgent`` configured to: + +- Investigate one AML case at a time. +- Use read-only SQL tools for schema discovery and data retrieval. +- Return structured output that conforms to ``AnalystOutput``. 
+ +Examples +-------- +>>> from aieng.agent_evals.aml_investigation.agent import create_aml_investigation_agent +>>> agent = create_aml_investigation_agent() +>>> agent.name +'AmlInvestigationAnalyst' +""" + +from aieng.agent_evals.aml_investigation.data import AnalystOutput +from aieng.agent_evals.async_client_manager import AsyncClientManager +from aieng.agent_evals.langfuse import init_tracing +from google.adk.agents import LlmAgent +from google.adk.agents.base_agent import AfterAgentCallback, BeforeAgentCallback +from google.adk.agents.llm_agent import AfterModelCallback, BeforeModelCallback +from google.adk.tools.function_tool import FunctionTool +from google.genai.types import GenerateContentConfig, ThinkingConfig + + +_DEFAULT_AGENT_DESCRIPTION = "Conducts multi-step investigations for money laundering patterns using database queries." + +ANALYST_PROMPT = """\ +You are an Anti‑Money Laundering (AML) Investigation Analyst at a financial institution. \ +Your job is to investigate one case by reviewing activity in the available database and explaining whether the \ +observed behavior within the case window is consistent with money laundering or a benign explanation. + +You have access to database query tools. Use them. Do not guess or invent transactions. + +## Core Principle: Falsification +Start with the hypothesis that the case is benign. Prefer legitimate explanations unless the transaction-level evidence supports laundering. + +## Input +You will be given a JSON object with these fields: +- `case_id`: unique case identifier. +- `seed_transaction_id`: identifier for the primary transaction that triggered the case. +- `seed_timestamp`: timestamp of the seed transaction (end of the investigation window). +- `window_start`: timestamp of the beginning of the investigation window. +- `trigger_label`: upstream alert/review label or heuristic hint (may be wrong). 
+ +### Time Scope Constraint +**Critical**: Only analyze events with timestamps between `window_start` and `seed_timestamp` (inclusive). Exclude any events after `seed_timestamp`. + +## Investigation Workflow + +### Step 1: Orient +Review the `trigger_label` as context only. Do not assume it is correct. + +### Step 2: Seed Transaction Review +- Query the seed transaction using `seed_transaction_id` +- Extract: involved parties, amounts, payment channels, instruments, and other relevant attributes + +### Step 3: Scope and Collect +Pull related activity for involved entities within the investigation window (`window_start` to `seed_timestamp`, inclusive). + +### Step 4: Assess Benign Explanations (Default Hypothesis) +Attempt to explain observed activity as legitimate first: +- State which evidence supports the benign hypothesis +- Identify what additional data would strengthen this explanation +- Only proceed to Step 5 if benign explanations are insufficient + +### Step 5: Test Laundering Hypotheses (If Needed) +If benign explanations fail to account for the evidence: +- Test whether the evidence supports known laundering typologies +- Cite concrete indicators that rule out benign explanations + +## Typologies / Heuristics +When assessing patterns, consider these typologies: +- FAN-IN (aggregation): Many sources aggregating to one destination +- FAN-OUT (dispersion): One source dispersing to many destinations +- GATHER-SCATTER / SCATTER-GATHER: Aggregation followed by dispersion (or reverse) over short time windows +- STACK / LAYERING: Multiple hops meant to obscure origin +- CYCLE: Circular fund movement +- BIPARTITE: Structured flows between two distinct groups +- RANDOM: Complex pattern with no discernible structure + +## Output Format +Return a single JSON object matching the configured output schema exactly. Populate every field. +Use `pattern_type = "NONE"` when no laundering pattern is supported by evidence in the investigation window. 
+ +## Handling Uncertainty +If you lack sufficient information to make a determination, explicitly state what data is missing. \ +Do not fabricate transaction details or make unsupported inferences. When uncertain between benign and suspicious, \ +default to "NONE" and document why evidence is insufficient +""" + + +def create_aml_investigation_agent( + name: str = "AmlInvestigationAnalyst", + *, + description: str | None = None, + instructions: str | None = None, + temperature: float | None = None, + top_p: float | None = None, + top_k: float | None = None, + max_output_tokens: int | None = None, + presence_penalty: float | None = None, + frequency_penalty: float | None = None, + seed: int | None = None, + before_agent_callback: BeforeAgentCallback | None = None, + after_agent_callback: AfterAgentCallback | None = None, + before_model_callback: BeforeModelCallback | None = None, + after_model_callback: AfterModelCallback | None = None, + enable_tracing: bool = True, +) -> LlmAgent: + """Create a configured AML investigation agent. + + This factory builds a Google ADK ``LlmAgent`` with domain-specific instructions, + read-only SQL tools, and a strict structured output schema. + + Parameters + ---------- + name : str, default="AmlInvestigationAnalyst" + Name assigned to the agent. This name appears in traces and logs and can + help distinguish multiple agents in a shared environment. + description : str | None, optional + Optional short description of the agent's purpose. If not provided, a + default AML investigation description is used. + instructions : str | None, optional + Optional system prompt for the agent. If omitted, the module-level + ``ANALYST_PROMPT`` is used. + temperature : float | None, optional + Sampling temperature for model generation. ``None`` uses provider/model + defaults. + top_p : float | None, optional + Nucleus sampling parameter. ``None`` uses provider/model defaults. + top_k : float | None, optional + Top-k sampling parameter. 
``None`` uses provider/model defaults. + max_output_tokens : int | None, optional + Maximum number of tokens the model can generate in a single response. + ``None`` uses provider/model defaults. + presence_penalty : float | None, optional + Penalty to encourage introducing new tokens. ``None`` uses + provider/model defaults. + frequency_penalty : float | None, optional + Penalty to discourage repeated tokens. ``None`` uses provider/model + defaults. + seed : int | None, optional + Optional random seed for more repeatable generations where supported by + the model/provider. + before_agent_callback : BeforeAgentCallback | None, optional + Callback executed before each agent run. + after_agent_callback : AfterAgentCallback | None, optional + Callback executed after each agent run. + before_model_callback : BeforeModelCallback | None, optional + Callback executed before each model call. + after_model_callback : AfterModelCallback | None, optional + Callback executed after each model call. + enable_tracing : bool, optional, default=True + Whether to initialize Langfuse tracing for this agent. If ``True``, Langfuse + tracing is initialized with the agent's name as the service name. + + Returns + ------- + LlmAgent + Configured AML investigation agent with: + + - Planner model from global configuration. + - Read-only SQL tools for schema and query execution. + - ``AnalystOutput`` as the enforced response schema. + - Reasoning/thought collection enabled through thinking config. + + Examples + -------- + >>> # Build the agent with defaults: + >>> agent = create_aml_investigation_agent() + >>> isinstance(agent, LlmAgent) + True + >>> # Build the agent with a custom name and deterministic settings: + >>> agent = create_aml_investigation_agent( + ... name="aml_eval_agent", + ... temperature=0.0, + ... seed=42, + ... 
)
+    >>> agent.name
+    'aml_eval_agent'
+    """
+    # Get the client manager singleton instance
+    client_manager = AsyncClientManager.get_instance()
+    db = client_manager.aml_db(agent_name=name)
+
+    # Initialize Langfuse tracing for this agent if enabled
+    if enable_tracing:
+        init_tracing(service_name=name)
+
+    return LlmAgent(
+        name=name,
+        description=description or _DEFAULT_AGENT_DESCRIPTION,
+        before_agent_callback=before_agent_callback,
+        after_agent_callback=after_agent_callback,
+        model=client_manager.configs.default_planner_model,
+        instruction=instructions or ANALYST_PROMPT,
+        tools=[FunctionTool(db.get_schema_info), FunctionTool(db.execute)],
+        generate_content_config=GenerateContentConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_output_tokens=max_output_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            seed=seed,
+            thinking_config=ThinkingConfig(include_thoughts=True),
+        ),
+        output_schema=AnalystOutput,
+        before_model_callback=before_model_callback,
+        after_model_callback=after_model_callback,
+    )
diff --git a/aieng-eval-agents/aieng/agent_evals/aml_investigation/task.py b/aieng-eval-agents/aieng/agent_evals/aml_investigation/task.py
new file mode 100644
index 0000000..d3878e9
--- /dev/null
+++ b/aieng-eval-agents/aieng/agent_evals/aml_investigation/task.py
@@ -0,0 +1,160 @@
+"""Task function for AML investigation experiment execution.
+
+This module provides a Langfuse-compatible task callable that executes the AML
+investigation agent on one dataset item and returns a structured analyst output.
+
+The task is designed for use with the evaluation harness and Langfuse
+``run_experiment`` APIs. It handles:
+
+- Input normalization for both dict and dataset item objects.
+- Running the ADK agent through a shared ``Runner``.
+- Extracting and validating final model output.
+- Returning consistent ``dict`` results for evaluator consumption. 
+ +Examples +-------- +>>> import asyncio +>>> from aieng.agent_evals.aml_investigation.task import AmlInvestigationTask +>>> task = AmlInvestigationTask() +>>> # Run one AML case in an async context +>>> sample_item = { +... "input": { +... "case_id": "case-001", +... "seed_transaction_id": "txn-001", +... "seed_timestamp": "2022-09-01T12:00:00", +... "window_start": "2022-09-01T00:00:00", +... "trigger_label": "RANDOM_REVIEW", +... } +... } +>>> _ = asyncio.run(task(item=sample_item)) +>>> # Use the task in an experiment +>>> from aieng.agent_evals.evaluation.experiment import run_experiment +>>> result = run_experiment( +... dataset_name="aml_eval_dataset", +... name="AML Investigation Evaluation", +... task=AmlInvestigationTask(), +... evaluators=[...], +... ) +""" + +import getpass +import json +import logging +import uuid +from typing import Any + +from aieng.agent_evals.aml_investigation.agent import create_aml_investigation_agent +from aieng.agent_evals.aml_investigation.data import AnalystOutput +from aieng.agent_evals.async_client_manager import AsyncClientManager +from google.adk.agents import LlmAgent +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types +from langfuse.experiment import ExperimentItem + + +logger = logging.getLogger(__name__) + + +class AmlInvestigationTask: + """Langfuse-compatible task wrapper for AML case investigations. + + This class implements the ``TaskFunction`` callable protocol expected by + Langfuse experiments: ``__call__(*, item, **kwargs)``. + + A single task instance owns: + + - One AML investigation agent. + - One ADK runner used to execute agent calls. + + Parameters + ---------- + agent : LlmAgent | None, optional + Pre-configured AML investigation agent to use. If ``None``, the default + factory ``create_aml_investigation_agent()`` is used. 
+ + Examples + -------- + >>> # Create a task with the default agent: + >>> task = AmlInvestigationTask() + >>> isinstance(task, AmlInvestigationTask) + True + >>> # Create a task with a custom agent: + >>> from aieng.agent_evals.aml_investigation import create_aml_investigation_agent + >>> custom_agent = create_aml_investigation_agent(name="aml_custom") + >>> task = AmlInvestigationTask(agent=custom_agent) + """ + + def __init__(self, *, agent: LlmAgent | None = None) -> None: + """Initialize the AML task with an agent and runner.""" + self._agent = agent or create_aml_investigation_agent() + self._runner = Runner( + app_name="aml_investigation", + agent=self._agent, + session_service=InMemorySessionService(), + auto_create_session=True, + ) + + async def __call__(self, *, item: ExperimentItem, **kwargs: Any) -> dict[str, Any] | None: + """Run one AML investigation case and return structured output. + + Parameters + ---------- + item : ExperimentItem + One Langfuse experiment item. This can be: + + - A dict-like local item with an ``"input"`` key. + - A Langfuse dataset item object with an ``input`` attribute. + + The input payload is serialized to JSON and passed as the user + message to the agent. + **kwargs : Any + Additional keyword arguments forwarded by Langfuse. They are + accepted for protocol compatibility and ignored by this task. + + Returns + ------- + dict[str, Any] | None + Parsed analyst output as a dictionary if a valid final response was + produced, otherwise ``None``. + + Notes + ----- + The method first attempts strict schema parsing with + ``AnalystOutput.model_validate_json``. If that fails, it falls back to a + direct ``json.loads`` parse and validates the resulting object. 
+ """ + item_input = item.get("input") if isinstance(item, dict) else item.input + serialized_input = json.dumps(item_input, ensure_ascii=False, indent=2) + message = types.Content(parts=[types.Part(text=serialized_input)], role="user") + + final_text: str | None = None + async for event in self._runner.run_async( + session_id=str(uuid.uuid4()), user_id=getpass.getuser(), new_message=message + ): + if event.is_final_response() and event.content and event.content.parts: + final_text = "".join(part.text or "" for part in event.content.parts if part.text) + + if not final_text: + metadata = item.get("metadata", {}) if isinstance(item, dict) else item.metadata + case_id = metadata.get("id") if metadata else "unknown" + logger.warning("No analyst output produced for case_id=%s", case_id) + return None + + # Prefer strict schema parse first if output_schema is respected. + try: + return AnalystOutput.model_validate_json(final_text.strip()).model_dump() + except Exception: + # fallback: extract JSON substring if needed + return AnalystOutput.model_validate(json.loads(final_text)).model_dump() + + async def close(self) -> None: + """Close runner and database connections used by this task instance. + + Notes + ----- + This method should be called when the task instance is no longer needed, + especially in long-running processes or repeated evaluation runs. 
+ """ + await self._runner.close() + AsyncClientManager.get_instance().aml_db().close() diff --git a/implementations/aml_investigation/README.md b/implementations/aml_investigation/README.md index c1b0807..3f2c9df 100644 --- a/implementations/aml_investigation/README.md +++ b/implementations/aml_investigation/README.md @@ -71,28 +71,26 @@ Each case contains: ## Run the Agent (Batch) -This reads `aml_cases.jsonl`, runs the agent over any cases missing `analysis`, and writes: +This reads `aml_cases.jsonl`, runs the agent over any cases missing `output`, and writes: -- `implementations/aml_investigation/data/aml_cases_with_analysis.jsonl` +- `implementations/aml_investigation/data/aml_cases_with_output.jsonl` ```bash -uv run --env-file .env implementations/aml_investigation/agent.py +uv run --env-file .env implementations/aml_investigation/cli.py ``` -The script prints a simple confusion matrix for `is_laundering` based on the cases that have `analysis`. +The script prints a simple confusion matrix for `is_laundering` based on the cases that have `output`. ## Run with ADK Web UI If you want to inspect the agent interactively, the module exposes a top-level `root_agent` for ADK discovery. -From `implementations/aml_investigation/`: +Run: ```bash uv run adk web --port 8000 --reload --reload_agents implementations/ ``` -The DB tool is initialized lazily when a tool call happens (so importing the module doesn’t keep a DB connection open). - ## Safety Notes (Why Read‑Only SQL?) Agents' access to operational databases should be limited to prevent accidental or malicious data modification. diff --git a/implementations/aml_investigation/agent.py b/implementations/aml_investigation/agent.py index f3a2b6e..013b04a 100644 --- a/implementations/aml_investigation/agent.py +++ b/implementations/aml_investigation/agent.py @@ -1,297 +1,21 @@ -"""AML Investigation Agent Implementation. +"""ADK discovery entrypoint for the AML investigation demo. 
+ +It exposes a module-level ``root_agent`` so ``adk web`` can discover it. Examples -------- -Run the agent on AML cases from JSONL file: - uv run --env-file .env implementations/aml_investigation/agent.py -Run with adk web: +Run with ``adk web``: uv run adk web --port 8000 --reload --reload_agents implementations/ """ -import asyncio -import getpass -import json import logging -import os -import uuid -from functools import lru_cache, partial -from pathlib import Path -import google.genai.types -from aieng.agent_evals.aml_investigation.data import AnalystOutput, CaseRecord -from aieng.agent_evals.async_client_manager import AsyncClientManager -from aieng.agent_evals.async_utils import rate_limited -from aieng.agent_evals.tools import ReadOnlySqlDatabase -from dotenv import load_dotenv -from google.adk.agents import Agent -from google.adk.runners import Runner -from google.adk.sessions import InMemorySessionService -from google.adk.tools import FunctionTool -from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn +from aieng.agent_evals.aml_investigation.agent import create_aml_investigation_agent +logging.basicConfig(level=logging.INFO, format="%(message)s") logger = logging.getLogger(__name__) -load_dotenv() - -MAX_CONCURRENT_CASES = 5 # Limit for concurrent case analyses - -ANALYST_PROMPT = """\ -You are an Anti‑Money Laundering (AML) Investigation Analyst at a financial institution. -Your job is to investigate one case by reviewing activity in the available database and explaining whether the -observed behavior within the case window is consistent with money laundering or a benign explanation. - -You have access to database query tools. Use them. Do not guess or invent transactions. - -## Core Principle: Falsification -Start with the hypothesis that the case is benign. Prefer legitimate explanations unless the transaction-level evidence -supports laundering. 
- -## Input -You will be given a JSON object with these fields: -- `case_id`: unique case identifier. -- `seed_transaction_id`: identifier for the primary transaction that triggered the case. -- `seed_timestamp`: timestamp of the seed transaction (end of the investigation window). -- `window_start`: timestamp of the beginning of the investigation window. -- `trigger_label`: upstream alert/review label or heuristic hint (may be wrong). - -### Time Scope Rule (Strict) -Only analyze events with `timestamp` between `window_start` and `seed_timestamp` (inclusive). -Do not use events after `seed_timestamp`. - -## Investigation Workflow -1) **Orient** - - Treat `trigger_label` as context only. Do not assume it is correct. -2) **Seed review** - - Query the seed event/transaction using `seed_transaction_id`. - - Extract key attributes available in this database (e.g., involved parties, amounts, payment channel/instrument). -3) **Scope and collect** - - Pull related activity for involved entities between `window_start` and `seed_timestamp` (inclusive). -4) **Assess benign explanations (default)** - - Try to explain the observed activity as legitimate first. - - State what evidence supports the benign hypothesis and what data would be needed to strengthen it. -5) **Test laundering hypotheses (only if needed)** - - Only if benign explanations are insufficient, test whether the evidence supports laundering typologies or other - suspicious behavior. - - Cite the concrete indicators that rule out benign explanations. - -## Typologies / Heuristics -Look for transaction patterns consistent with laundering typologies, such as: -- FAN-IN (aggregation): many sources to one destination -- FAN-OUT (dispersion): one source to many destinations -- GATHER-SCATTER / SCATTER-GATHER: aggregation then dispersion (or vice‑versa), often over short time windows. 
-- STACK / LAYERING: multiple hops meant to obscure origin -- CYCLE: circular movement -- RANDOM: complext pattern -- BIPARTITE: structured flows between two groups - -## Output Format -Return a single JSON object that matches the configured output schema exactly. Fill every field. -Use `pattern_type = "NONE"` when no laundering pattern is supported by evidence in the investigation window. -""" - - -@lru_cache(maxsize=1) -def _get_db() -> ReadOnlySqlDatabase: - """Lazily construct the read-only database tool from environment configuration.""" - client_manager = AsyncClientManager().get_instance() - return client_manager.aml_db() - - -async def _try_close_db() -> None: - """Close the lazily initialized database tool if it was created.""" - if _get_db.cache_info().currsize: - client_manager = AsyncClientManager().get_instance() - await client_manager.close() - _get_db.cache_clear() - # ADK discovery expects a module-level `root_agent` -root_agent = Agent( - name="AmlInvestigationAnalyst", - description="Conducts multi-step financial crime investigations using database queries.", - tools=[FunctionTool(_get_db().get_schema_info), FunctionTool(_get_db().execute)], - model="gemini-3-flash-preview", - instruction=ANALYST_PROMPT, - output_schema=AnalystOutput, - generate_content_config=google.genai.types.GenerateContentConfig( - thinking_config=google.genai.types.ThinkingConfig(include_thoughts=True) - ), -) - - -def _load_records(path: Path) -> list[CaseRecord]: - """Load CaseRecord rows from a JSONL file, skipping invalid lines.""" - if not path.exists(): - return [] - - records: list[CaseRecord] = [] - with path.open("r", encoding="utf-8") as file: - for line_number, line in enumerate(file, start=1): - stripped_line = line.strip() - if not stripped_line: - continue - try: - records.append(CaseRecord.model_validate_json(stripped_line)) - except Exception as exc: - logger.warning("Skipping invalid JSONL record at %s:%d (%s)", path, line_number, exc) - return records - - 
-def _extract_json(text: str) -> dict: - """Parse JSON from model output, falling back to the first JSON object substring.""" - try: - return json.loads(text) - except json.JSONDecodeError: - start = text.find("{") - end = text.rfind("}") - if start == -1 or end == -1 or end <= start: - raise - return json.loads(text[start : end + 1]) - - -def _write_results(output_path: Path, input_records: list[CaseRecord], results_by_id: dict[str, CaseRecord]) -> int: - """Rewrite the output JSONL with updated analyses, preserving input order.""" - tmp_path = output_path.with_suffix(output_path.suffix + ".tmp") - written: set[str] = set() - analyzed = 0 - - with tmp_path.open("w", encoding="utf-8") as outfile: - for record in input_records: - case_id = record.input.case_id - if case_id in written: - continue - written.add(case_id) - out_record = results_by_id.get(case_id, record) - analyzed += int(out_record.output is not None) - outfile.write(out_record.model_dump_json() + "\n") - - tmp_path.replace(output_path) - return analyzed - - -async def _analyze_case(runner: Runner, record: CaseRecord) -> CaseRecord: - """Run the agent on one case and attach the validated AnalystOutput.""" - message = google.genai.types.Content( - role="user", parts=[google.genai.types.Part(text=record.input.model_dump_json())] - ) - events_async = runner.run_async(session_id=str(uuid.uuid4()), user_id=getpass.getuser(), new_message=message) - - final_text: str | None = None - async for event in events_async: - if event.is_final_response() and event.content and event.content.parts: - final_text = "".join(part.text or "" for part in event.content.parts if part.text) - - if not final_text: - logger.warning("No analyst output produced for case_id=%s", record.input.case_id) - return record - - record.output = AnalystOutput.model_validate(_extract_json(final_text.strip())) - return record - - -async def _safe_analyze_case(runner: Runner, record: CaseRecord) -> CaseRecord: - """Analyze a case and swallow 
exceptions so batch runs continue.""" - try: - return await _analyze_case(runner, record) - except Exception as exc: - logger.exception("Case failed (case_id=%s): %s", record.input.case_id, exc) - return record - - -async def _analyze_cases_to_jsonl( - runner: Runner, - cases: list[CaseRecord], - semaphore: asyncio.Semaphore, - output_path: Path, -) -> dict[str, CaseRecord]: - """Analyze cases concurrently and append each result to a JSONL output file.""" - if not cases: - return {} - - output_path.parent.mkdir(parents=True, exist_ok=True) - tasks = [ - asyncio.create_task(rate_limited(partial(_safe_analyze_case, runner, record), semaphore)) for record in cases - ] - - analyzed_by_id: dict[str, CaseRecord] = {} - with ( - output_path.open("a", encoding="utf-8") as outfile, - Progress( - TextColumn("[bold blue]{task.description}"), - BarColumn(), - TextColumn("{task.completed}/{task.total}"), - TimeRemainingColumn(), - TimeElapsedColumn(), - ) as progress, - ): - progress_task = progress.add_task("Analyzing AML cases", total=len(tasks)) - - for finished in asyncio.as_completed(tasks): - record = await finished - analyzed_by_id[record.input.case_id] = record - outfile.write(record.model_dump_json() + "\n") - outfile.flush() - os.fsync(outfile.fileno()) - progress.update(progress_task, advance=1) - - return analyzed_by_id - - -async def _main() -> None: - """Run the AML investigation agent on cases from JSONL.""" - input_path = Path("implementations/aml_investigation/data/aml_cases.jsonl") - if not input_path.exists(): - raise FileNotFoundError(f"Case JSONL not found at {input_path.resolve()}") - - output_path = input_path.with_name("aml_cases_with_analysis.jsonl") - output_path.parent.mkdir(parents=True, exist_ok=True) - - input_records = _load_records(input_path) - existing_results = {record.input.case_id: record for record in _load_records(output_path)} - to_run = [r for r in input_records if existing_results.get(r.input.case_id, r).output is None] - - 
logger.info("Resume: %d/%d done; %d remaining.", len(input_records) - len(to_run), len(input_records), len(to_run)) - - try: - runner = Runner( - app_name="aml_investigation", - agent=root_agent, - session_service=InMemorySessionService(), - auto_create_session=True, - ) - analyzed_by_id = await _analyze_cases_to_jsonl( - runner, to_run, asyncio.Semaphore(MAX_CONCURRENT_CASES), output_path - ) - existing_results.update(analyzed_by_id) - analyzed_count = _write_results(output_path, input_records, existing_results) - logger.info("Wrote %d analyzed cases to %s", analyzed_count, output_path) - - final_records = [existing_results.get(r.input.case_id, r) for r in input_records] - scored = [r for r in final_records if r.output is not None] - if not scored: - logger.info("Metrics: N/A (no analyzed cases)") - else: - tp = fp = fn = tn = 0 - for r in scored: - gt = r.expected_output.is_laundering - assert r.output is not None # Guaranteed by filter above - pred = r.output.is_laundering - if gt and pred: - tp += 1 - elif (not gt) and pred: - fp += 1 - elif gt and (not pred): - fn += 1 - else: - tn += 1 - logger.info("is_laundering confusion matrix:") - logger.info(" TP=%d FP=%d", tp, fp) - logger.info(" FN=%d TN=%d", fn, tn) - finally: - await _try_close_db() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, format="%(message)s") - asyncio.run(_main()) +root_agent = create_aml_investigation_agent(enable_tracing=True) diff --git a/implementations/aml_investigation/cli.py b/implementations/aml_investigation/cli.py new file mode 100644 index 0000000..15ba040 --- /dev/null +++ b/implementations/aml_investigation/cli.py @@ -0,0 +1,271 @@ +r"""Run AML cases from JSONL. + +This module provides a CLI for running the AML investigation agent over cases +defined in a JSONL file. The workflow is: + +1. Read input cases from JSONL. +2. Run the AML task over pending cases. Save intermediate results and show progress + while cases run. +3. 
Write results back to JSONL. +4. Print a simple confusion matrix. + +Examples +-------- +Run with defaults: + uv run --env-file .env implementations/aml_investigation/cli.py + +Run with custom settings: + uv run --env-file .env implementations/aml_investigation/cli.py \ + --input-path implementations/aml_investigation/data/aml_cases.jsonl \ + --output-path implementations/aml_investigation/data/aml_cases_with_output.jsonl \ + --max-concurrent-cases 8 \ + --resume +""" + +import asyncio +import logging +import os +from pathlib import Path + +import click +from aieng.agent_evals.aml_investigation.agent import create_aml_investigation_agent +from aieng.agent_evals.aml_investigation.data import AnalystOutput, CaseRecord +from aieng.agent_evals.aml_investigation.task import AmlInvestigationTask +from aieng.agent_evals.progress import create_progress +from langfuse.experiment import LocalExperimentItem +from rich.logging import RichHandler + + +logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=[RichHandler(show_path=False)], force=True) +logger = logging.getLogger(__name__) + +DEFAULT_INPUT_PATH = Path("implementations/aml_investigation/data/aml_cases.jsonl") +DEFAULT_OUTPUT_FILENAME = "aml_cases_with_output.jsonl" +DEFAULT_MAX_CONCURRENT_CASES = 10 + + +def _load_case_records(path: Path) -> list[CaseRecord]: + """Load case records from a JSONL file. + + Invalid rows are skipped with a warning. 
+ """ + if not path.exists(): + return [] + + records: list[CaseRecord] = [] + with path.open("r", encoding="utf-8") as file: + for line_number, line in enumerate(file, start=1): + stripped = line.strip() + if not stripped: + continue + try: + records.append(CaseRecord.model_validate_json(stripped)) + except Exception as exc: + logger.warning("Skipping invalid JSONL row at %s:%d (%s)", path, line_number, exc) + return records + + +def _write_case_records(path: Path, records: list[CaseRecord]) -> None: + """Write case records to JSONL.""" + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as file: + for record in records: + file.write(record.model_dump_json() + "\n") + + +def _merge_records_in_input_order( + input_records: list[CaseRecord], updates_by_id: dict[str, CaseRecord] +) -> list[CaseRecord]: + """Merge updates back into the original input order. + + If duplicate ``case_id`` values exist in input, only the first instance is kept. + """ + merged: list[CaseRecord] = [] + seen: set[str] = set() + for record in input_records: + case_id = record.input.case_id + if case_id in seen: + continue + seen.add(case_id) + merged.append(updates_by_id.get(case_id, record)) + return merged + + +def _log_is_laundering_confusion_matrix(records: list[CaseRecord]) -> None: + """Log a simple confusion matrix using records that have predictions.""" + scored = [record for record in records if record.output is not None] + if not scored: + logger.info("Metrics: N/A (no analyzed cases)") + return + + tp = fp = fn = tn = 0 + for record in scored: + gt = record.expected_output.is_laundering + assert record.output is not None + pred = record.output.is_laundering + if gt and pred: + tp += 1 + elif (not gt) and pred: + fp += 1 + elif gt and (not pred): + fn += 1 + else: + tn += 1 + + logger.info("is_laundering confusion matrix:") + logger.info(" TP=%d FP=%d", tp, fp) + logger.info(" FN=%d TN=%d", fn, tn) + + +async def _run_case(task: 
AmlInvestigationTask, record: CaseRecord, semaphore: asyncio.Semaphore) -> CaseRecord: + """Run one case. + + This function is intentionally defensive: it logs errors and always returns + a ``CaseRecord`` so the full batch can continue. + """ + try: + async with semaphore: + item: LocalExperimentItem = {"input": record.input.model_dump(), "metadata": {"id": record.input.case_id}} + output = await task(item=item) + + if output is None: + logger.warning("No analyst output produced for case_id=%s", record.input.case_id) + return record + + record.output = AnalystOutput.model_validate(output) + return record + except Exception as exc: + logger.exception("Case failed (case_id=%s): %s", record.input.case_id, exc) + return record + + +async def run_cases( + input_path: Path = DEFAULT_INPUT_PATH, + output_path: Path | None = None, + max_concurrent_cases: int = DEFAULT_MAX_CONCURRENT_CASES, + resume: bool = True, +) -> Path: + """Run AML investigations for cases from an input JSONL file. + + Parameters + ---------- + input_path : Path, optional + Input case JSONL path. + output_path : Path | None, optional + Output JSONL path. If ``None``, a default filename is used in the input + directory. + max_concurrent_cases : int, optional + Maximum number of cases to process at the same time. + resume : bool, optional + If ``True``, cases already analyzed in the output file are skipped. + + Returns + ------- + Path + Final output path. 
+ """ + if max_concurrent_cases <= 0: + raise ValueError("max_concurrent_cases must be > 0") + if not input_path.exists(): + raise FileNotFoundError(f"Case JSONL file not found at {input_path.resolve()}") + + resolved_output_path = output_path or input_path.with_name(DEFAULT_OUTPUT_FILENAME) + if not resume and resolved_output_path.exists(): + resolved_output_path.unlink() + resolved_output_path.parent.mkdir(parents=True, exist_ok=True) + + input_records = _load_case_records(input_path) + existing_by_id: dict[str, CaseRecord] = {} + if resume: + for record in _load_case_records(resolved_output_path): + existing_by_id[record.input.case_id] = record + + if resume: + pending = [ + record for record in input_records if existing_by_id.get(record.input.case_id, record).output is None + ] + logger.info( + "Resume: %d/%d done; %d remaining.", len(input_records) - len(pending), len(input_records), len(pending) + ) + else: + pending = input_records + logger.info("Running %d/%d cases from scratch.", len(pending), len(input_records)) + + semaphore = asyncio.Semaphore(max_concurrent_cases) + agent = create_aml_investigation_agent(enable_tracing=True) + task_runner = AmlInvestigationTask(agent=agent) + tasks = [asyncio.create_task(_run_case(task_runner, record, semaphore)) for record in pending] + try: + with resolved_output_path.open("a", encoding="utf-8") as checkpoint_file, create_progress() as progress: + progress_task = progress.add_task("Analyzing AML cases", total=len(tasks)) + for finished in asyncio.as_completed(tasks): + record = await finished + existing_by_id[record.input.case_id] = record + + # Save each completed case immediately so resume works after + # cancellation/crash and we do not waste API calls + checkpoint_file.write(record.model_dump_json() + "\n") + checkpoint_file.flush() + os.fsync(checkpoint_file.fileno()) + + progress.update(progress_task, advance=1) + except asyncio.CancelledError: + logger.warning("Run cancelled. 
Partial results are saved in %s", resolved_output_path) + raise + finally: + await task_runner.close() + + final_records = _merge_records_in_input_order(input_records, existing_by_id) + _write_case_records(resolved_output_path, final_records) + logger.info("Wrote %d analyzed cases to %s", sum(r.output is not None for r in final_records), resolved_output_path) + + _log_is_laundering_confusion_matrix(final_records) + return resolved_output_path + + +@click.command() +@click.option( + "--input-path", + type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path), + default=DEFAULT_INPUT_PATH, + show_default=True, + help="Input case JSONL file.", +) +@click.option( + "--output-path", + type=click.Path(dir_okay=False, writable=True, path_type=Path), + default=None, + help="Output JSONL file. Defaults to aml_cases_with_output.jsonl next to input.", +) +@click.option( + "--max-concurrent-cases", + type=int, + default=DEFAULT_MAX_CONCURRENT_CASES, + show_default=True, + help="Maximum number of cases to run at once.", +) +@click.option( + "--resume/--no-resume", + default=True, + show_default=True, + help="Skip cases that already have analysis in the output file.", +) +def cli(input_path: Path, output_path: Path | None, max_concurrent_cases: int, resume: bool) -> None: + """Run the AML demo workflow over example cases.""" + if max_concurrent_cases <= 0: + raise click.BadParameter("max-concurrent-cases must be > 0", param_hint="--max-concurrent-cases") + + logging.basicConfig(level=logging.INFO, format="%(message)s") + final_path = asyncio.run( + run_cases( + input_path=input_path, + output_path=output_path, + max_concurrent_cases=max_concurrent_cases, + resume=resume, + ) + ) + click.echo(f"Done. Output written to {final_path}") + + +if __name__ == "__main__": + cli()