From 288980bf612f9559b2bde9abbc76dde3ffbae8b9 Mon Sep 17 00:00:00 2001 From: Chris Zubak-Skees Date: Thu, 18 Dec 2025 18:27:08 +0800 Subject: [PATCH] Add script to optimize experimental LLM matching of application fields to base fields --- standalone/.gitignore | 8 + standalone/data/optimized_rag.json | 32 ++ standalone/pdc-optimize-field-matching | 421 +++++++++++++++++++++++++ 3 files changed, 461 insertions(+) create mode 100644 standalone/.gitignore create mode 100644 standalone/data/optimized_rag.json create mode 100755 standalone/pdc-optimize-field-matching diff --git a/standalone/.gitignore b/standalone/.gitignore new file mode 100644 index 0000000..db08b29 --- /dev/null +++ b/standalone/.gitignore @@ -0,0 +1,8 @@ +.DS_Store + +# MLFlow +mlruns/ +mlflow.db + +data/* +!data/optimized_rag.json diff --git a/standalone/data/optimized_rag.json b/standalone/data/optimized_rag.json new file mode 100644 index 0000000..48b1235 --- /dev/null +++ b/standalone/data/optimized_rag.json @@ -0,0 +1,32 @@ +{ + "respond": { + "traces": [], + "train": [], + "demos": [], + "signature": { + "instructions": "Which short codes best correspond to the given label? Only respond with 1-3 short codes from the provided list. If none are a good fit, respond with 'unknown'.\n\nImportant guidelines:\n1. Return only the plain short code text without any additional formatting, quotation marks, or special characters like «»\n2. The short code should match exactly as it appears in the context list\n3. Look for short codes that semantically match the label - consider both exact phrase matches and conceptual matches\n4. Pay attention to the full context of the label - for example, \"Type\" alone is ambiguous, but if the domain context suggests it refers to proposal types rather than organization types, prioritize accordingly\n5. When the label is ambiguous (like \"Type\"), consider that the short code might not be in the provided list - in such cases, respond with 'unknown' rather than guessing from tangentially related options\n6. Match the components of compound labels carefully - for example, \"Proposal Project Director LastName\" should map to a short code containing all these elements: proposal, project_director, and last_name\n7. Prioritize exact semantic matches over partial matches\n\nDomain-specific reasoning strategies:\n8. Distinguish between different scopes: 'proposal' prefix typically refers to project-specific information, while 'organization' prefix refers to institution-level information. For example, \"Project Budget\" should map to proposal-level budget fields, not organization-level ones.\n9. When a label uses general terms that could apply to multiple contexts (like \"news articles\" or \"reports about work\"), these rarely match standard organizational or proposal fields. Don't force a match with tangentially related codes like \"lobbying_activities\" just because they mention activities - instead, respond with 'unknown' if no code semantically captures the specific request.\n10. Prefer more specific codes over general ones when both exist. For instance, if both 'proposal_project_budget' and 'proposal_budget' are available and the label says \"Project Budget\", both may be valid matches as they're semantically equivalent in this context.\n11. Be conservative: only return codes that genuinely match the label's semantic meaning. If the label describes something that isn't represented in the available codes, always respond with 'unknown' rather than selecting a loosely related option.\n\nCritical context-awareness for date fields:\n12. When the label refers to a \"Complete Date\" or completion-related timing, carefully consider the workflow context. In review or application workflows, 'review_close_date' or 'review_end_date' typically indicates when a review process is complete, which is semantically different from 'proposal_end_date' (which refers to when a project ends). The word \"Complete\" in isolation often refers to when a process or review is finished, not when a project concludes.\n13. Distinguish between different types of dates: submission dates, review completion dates, project start/end dates, and fiscal dates. Match based on what action or milestone is being completed.\n14. 'proposal_date' typically refers to the date a proposal was submitted or created, not when something is completed.\n\nAdditional date field reasoning:\n15. In grant/proposal management systems, there are distinct lifecycle stages: proposal submission, review period, and project execution. Each has its own dates. Match the label to the appropriate stage based on semantic meaning.", + "fields": [ + { + "prefix": "Context:", + "description": "list of available short codes" + }, + { + "prefix": "Label:", + "description": "the label to match" + }, + { + "prefix": "Short Codes:", + "description": "Up to 3 short code that are likely matches for the label" + } + ] + }, + "lm": null + }, + "metadata": { + "dependency_versions": { + "python": "3.12", + "dspy": "3.0.4", + "cloudpickle": "3.1" + } + } +} diff --git a/standalone/pdc-optimize-field-matching b/standalone/pdc-optimize-field-matching new file mode 100755 index 0000000..5bd5b00 --- /dev/null +++ b/standalone/pdc-optimize-field-matching @@ -0,0 +1,421 @@ +#!/usr/bin/env pipx run + +# Copyright (C) PhilanthropyDataCommons.org +# GNU Affero General Public License, Version 3 (AGPL-3.0) + + +# Uses PEP 723 inline script metadata https://packaging.python.org/en/latest/specifications/inline-script-metadata/ +# to auto-install dependencies in dedicated environment on script execution with `pipx run`` or `uv run``. +# pipx is the default script runner, so `./pdc-form-matching` should work if pipx is installed, but +# manual Python env setup also still works, just run: +# `$ python -m venv venv` +# `$ source ./venv/bin/activate` +# `$ pip3 install` the following: +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "docling>=2.64.0", +# "dspy>=3.0.4", +# "faiss-cpu>=1.13.0", +# "mlflow>=3.6.0", +# "oauthlib>=3.3.1", +# "polars>=1.35.2", +# "pydantic>=2.12.5", +# "python-dotenv>=1.2.1", +# "requests-oauthlib>=2.0.0", +# "sentence-transformers>=5.2.0", +# "typer>=0.19.2", +# ] +# /// + +__doc__ = """Optimize a prompt to match form fields to base fields using retrieval augmented generation. + +This script uses DSPy to build a retrieval-augmented generation (RAG) module that matches form field labels from +application forms to base field short codes in the Philanthropy Data Commons (PDC). The script retrieves application +form data and base fields from the PDC API, extracts form fields from a document, combines mappings from various sources, +generates candidate short codes using a RAG model, evaluates the system's performance and optimizes its prompt using GEPA. + +Pre-requisites: +- `pipx` (can also use a virtual environment with the dependencies installed manually or `uv run`) + +Setup: +1. From the project root run: `cp .env.example .env` +2. Fill in `.env` with PDC credentials capable of retrieving form data with a client grant (can use data-ingest credentials) +3. Add `ANTHROPIC_API_KEY=your_anthropic_api_key_here` to the `.env` file + +Usage: +1. `cd ./standalone` +2. `./pdc-optimize-field-matching retrieve-base-fields` to download `data/baseFields.json` from the PDC API +3. `./pdc-optimize-field-matching retrieve-application-forms` to download application forms from the PDC API +4. `./pdc-optimize-field-matching transform-application-forms` to populate `data/non-matching-mappings.csv` with application form fields with labels that don't exactly match their associated base field labels. +5. `./pdc-optimize-field-matching extract-form-fields "data/2025 Stage 1 - LOI FINAL.docx"` to extract form fields from a sample application form document (you need to provide a document file). +6. Download or create a CSV called `hand-collected-mappings.csv` in `data/` with at least two columns: `label` and `instructions` and any additional mappings you have collected manually. +7. `./pdc-optimize-field-matching combine-mappings` to combine all three sources of data into `data/combined-mappings.csv` +8. `./pdc-optimize-field-matching generate-candidates` to generate candidate short codes for each label in `data/combined-mappings.csv` and save the results to `data/mappings-with-possible-short-codes.csv` +9. Manually review and edit `data/mappings-with-possible-short-codes.csv` and move any `possible_short_code` values that are correct into the `short_code` column, saving the results to `data/edited-mappings.csv` +10. `./pdc-optimize-field-matching split-datasets` to create training and test datasets from `data/edited-mappings.csv` +11. `./pdc-optimize-field-matching evaluate-rag` to evaluate the retrieval-augmented generation model on the test dataset. +12. `./pdc-optimize-field-matching optimize-rag` to optimize the retrieval-augmented generation model using the training dataset. +13. `./pdc-optimize-field-matching evaluate-rag` to see the improved evaluation results. + +Can observe MLFlow traces by running `pipx mlflow ui --backend-store-uri "sqlite:///mlflow.db" --default-artifact-root "mlruns"` and navigating to `http://localhost:5000` in your browser. +""" + +import json +import logging +import os +from pathlib import Path +from typing import Any + +import dspy +import mlflow +import polars as pl +import pydantic +import requests +import typer +from docling.document_converter import DocumentConverter +from dotenv import load_dotenv +from oauthlib.oauth2 import BackendApplicationClient +from requests_oauthlib import OAuth2Session +from sentence_transformers import SentenceTransformer + +logger = logging.getLogger(__name__) + +logging.getLogger("mlflow").setLevel(logging.WARNING) +logging.getLogger("alembic").setLevel(logging.WARNING) + +# Load environment variables +load_dotenv() + +DS_OIDC_BASE_URL = os.getenv("DS_OIDC_BASE_URL") +DS_OIDC_CLIENT_ID = os.getenv("DS_OIDC_CLIENT_ID") +DS_OIDC_CLIENT_SECRET = os.getenv("DS_OIDC_CLIENT_SECRET") +DS_PDC_API_BASE_URL = os.getenv("DS_PDC_API_BASE_URL") + +# Paths and configurations +base_path = Path("data/") + +optimized_rag_path = base_path / "optimized_rag.json" + +# TODO: make configurable +model = "anthropic/claude-sonnet-4-5-20250929" +embedding_model_name = "Snowflake/snowflake-arctic-embed-m-v1.5" + +topk_docs_to_retrieve = 10 # number of documents to retrieve per search query + +# MLFlow setup +tracking_uri = "sqlite:///mlflow.db" +mlflow.set_tracking_uri(tracking_uri) +mlflow.set_experiment("DSPy") + +mlflow.dspy.autolog() + +# CLI setup +app = typer.Typer() + +base_path.mkdir(exist_ok=True) + +# LLM configuration +lm = dspy.LM(model) +dspy.configure(lm=lm) + +# OAuth setup +client = BackendApplicationClient(client_id=DS_OIDC_CLIENT_ID) + + +def retrieve_pdc_data(endpoint: str, *, params: dict[str, Any] | None = None): + oauth = OAuth2Session(client=client) + + token_response = oauth.fetch_token( + token_url=f"{DS_OIDC_BASE_URL}/protocol/openid-connect/token", + client_id=DS_OIDC_CLIENT_ID, + client_secret=DS_OIDC_CLIENT_SECRET, + ) + + access_token = token_response["access_token"] + + pdc_data_response = requests.get( + f"{DS_PDC_API_BASE_URL}{endpoint}", + headers={"Authorization": f"Bearer {access_token}"}, + params=params or {}, + ) + + pdc_data_dict = pdc_data_response.json() + + with open(base_path / f"{endpoint}.json", "w", encoding="utf-8") as f: + json.dump(pdc_data_dict, f, indent=4, ensure_ascii=False) + + +def short_code_match(example, pred, trace=None, frac=1.0): + """ + Used to evaluate whether a predicted short code matches the expected short code. + """ + return example.shortCode in pred.shortCodes + + +def short_code_match_with_feedback( + example, pred, trace=None, pred_name=None, pred_trace=None +): + """ + Provides feedback to the LLM on whether the predicted short code matches the expected short code. + """ + # TODO: give less credit for matches that are lower in the list of predicted short codes + score = 1.0 if example.shortCode in pred.shortCodes else 0.0 + if example.shortCode in pred.shortCodes: + feedback = f"You correctly returned the short code as `{example.shortCode}`. The correct short code was indeed `{example.shortCode}`." + else: + feedback = f"You incorrectly returned the short code in `{', '.join(pred.shortCodes)}`. The correct answer is `{example.shortCode}`. Think about how you could have reasoned to get the correct short code." + return dspy.Prediction(score=score, feedback=feedback) + + +class FormField(pydantic.BaseModel): + """ + A representation of a form field from an application form. + """ + + position: int = pydantic.Field( + description="The position of the form field in the document" + ) + label: str = pydantic.Field(description="A form field label") + instructions: str | None = pydantic.Field( + description="Instructions for the form field" + ) + + +class FormFields(dspy.Signature): + """ + Please extract all form fields and their instructions from the form document provided below. + """ + + form: str = dspy.InputField(desc="The form document content in markdown format") + form_fields: list[FormField] = dspy.OutputField( + desc="A list of form fields with their details" + ) + + +class MatchedShortCodes(dspy.Signature): + # TODO: prompt to order short codes by likelihood and output a confidence score for each + """Which short codes best corresponds to the given label? Only respond with 1-3 short codes from the provided list. If none are a good fit, respond with 'unknown'.""" + + context = dspy.InputField(desc="list of available short codes") + label = dspy.InputField(desc="the label to match") + shortCodes: list[str] = dspy.OutputField( + desc="Up to 3 short code that are likely matches for the label" + ) + + +class RAG(dspy.Module): + """ + A retrieval-augmented generation DSPy module to match form field labels to base field short codes. + """ + + def __init__(self): + self.embedding_model = SentenceTransformer( + embedding_model_name, + ) + + self.embedder = dspy.Embedder(self.embedding_model.encode, caching=False) + self.respond = dspy.Predict(MatchedShortCodes) + base_fields_df = pl.read_json(base_path / "baseFields.json") + self.corpus = [d["shortCode"] for d in base_fields_df.to_dicts()] + self.search = dspy.retrievers.Embeddings( + embedder=self.embedder, corpus=self.corpus, k=topk_docs_to_retrieve + ) + + def forward(self, label): + context = self.search(label).passages + return self.respond( + context=context, + label=label, + instructions="Which short code best corresponds to the given label? Only respond with a short code from the provided list. If none are a good fit, respond with 'unknown'.", + ) + + + +@app.command() +def retrieve_application_forms(): + retrieve_pdc_data( + "applicationForms", + params={"_count": 1000, "_page": 1}, + ) + + +@app.command() +def retrieve_base_fields(): + retrieve_pdc_data("baseFields") + + +@app.command() +def transform_application_forms(): + with open(base_path / "applicationForms.json", "r", encoding="utf-8") as f: + application_form_dict = json.load(f) + + mappings = [] + + for form in application_form_dict["entries"]: + for field in form["fields"]: + if "baseField" in field: + mappings.append( + { + "label": field["label"], + # "position": field.get("position", None), + "instructions": field.get("instructions", None), + "short_code": field.get("baseFieldShortCode", None), + "base_field_label": field["baseField"]["label"], + } + ) + + mappings_df = pl.from_dicts(mappings).unique() + + mappings_df["label"].count() + + non_matching_mappings_df = mappings_df.filter( + pl.col("base_field_label") != pl.col("label") + ).select( + [ + "label", + # "position", + "instructions", + "short_code", + ] + ) + non_matching_mappings_df["label"].count() + + non_matching_mappings_df.write_csv(base_path / "non-matching-mappings.csv") + + +@app.command() +def extract_form_fields(source: str): + # TODO: make it so we can extract from multiple documents and deal with long document truncation + extract = dspy.Predict(FormFields) + + converter = DocumentConverter() + + result = converter.convert(source) + form = result.document.export_to_markdown() + + response = extract(form=form) + + extracted_fields = response.form_fields + + extracted_fields_df = pl.from_dicts( + [f.model_dump() for f in extracted_fields] + ).select("label", "instructions") # , "position") + + extracted_fields_df.write_csv(base_path / "extracted-form-fields.csv") + + +@app.command() +def combine_mappings(): + # TODO: make more flexible to permit arbitrary sources of mappings, make hand collected mappings optional + extracted_fields_df = pl.read_csv(base_path / "extracted-form-fields.csv") + non_matching_mappings_df = pl.read_csv(base_path / "non-matching-mappings.csv") + hand_collected_mappings_df = pl.read_csv( + base_path / "hand-collected-mappings.csv" + ).select(["label", "instructions"]) + + combined_mappings_df = pl.concat( + [non_matching_mappings_df, extracted_fields_df, hand_collected_mappings_df], + how="diagonal", + ) + + combined_mappings_df.write_csv(base_path / "combined-mappings.csv") + + +@app.command() +def generate_candidates(): + rag = RAG() + if os.path.exists(optimized_rag_path): + # Load optimized RAG if available + rag.load(optimized_rag_path) + + mappings_df = pl.read_csv(base_path / "combined-mappings.csv") + + mappings_with_possible_short_codes = [ + {**mapping, "possible_short_code": rag(label=mapping["label"]).shortCodes[0]} + for mapping in mappings_df.to_dicts() + ] + + pl.from_dicts(mappings_with_possible_short_codes).write_csv( + base_path / "mappings-with-possible-short-codes.csv" + ) + + +@app.command() +def split_datasets(): + mappings_df = pl.read_csv(base_path / "edited-mappings.csv").filter( + pl.col("short_code").is_not_null() + ) + + data = mappings_df.to_dicts() + + # TODO: implement more dynamic train/val/test split + trainset, testset = data[:39], data[39:122] + + pl.from_dicts(trainset).write_csv(base_path / "trainset.csv") + pl.from_dicts(testset).write_csv(base_path / "testset.csv") + + +@app.command() +def evaluate_rag(): + rag = RAG() + if os.path.exists(optimized_rag_path): + # Load optimized RAG if available + # TODO: make it possible to evaluate either base and optimized RAGs with a flag + rag.load(optimized_rag_path) + + testset = pl.read_csv(base_path / "testset.csv") + + testset_examples = [ + dspy.Example(**{**d, "shortCode": d["short_code"]}).with_inputs("label") + for d in testset.to_dicts() + ] + + # Evaluation setup + evaluate = dspy.Evaluate( + devset=testset_examples, + metric=short_code_match, + num_threads=2, + display_progress=False, + display_table=2, + ) + + evaluation = evaluate(rag) + + print(evaluation) + + +@app.command() +def optimize_rag(): + optimizer = dspy.GEPA( + metric=short_code_match_with_feedback, + auto="heavy", + num_threads=2, + track_stats=True, + use_merge=False, + reflection_lm=lm, # TODO: use a separate reflection LLM + ) + + rag = RAG() + if os.path.exists(optimized_rag_path): + # Load optimized RAG if available + rag.load(optimized_rag_path) + + trainset = pl.read_csv(base_path / "trainset.csv") + + trainset_examples = [ + dspy.Example(**{**d, "shortCode": d["short_code"]}).with_inputs("label") + for d in trainset.to_dicts() + ] + + optimized_rag = optimizer.compile( + rag, + trainset=trainset_examples, + valset=trainset_examples, # TODO: use a separate validation set + ) + + optimized_rag.save(base_path / "optimized_rag.json") + + +if __name__ == "__main__": + app()