From 120fcf995a2780f6cf535ad65dd85093ba619a87 Mon Sep 17 00:00:00 2001 From: Charlie Date: Fri, 30 Jan 2026 18:05:37 -0500 Subject: [PATCH 1/4] Add API endpoints for registry CLI --- sample_registry/app.py | 295 ++++++++++++++++++++++++++++++++--- sample_registry/registrar.py | 18 ++- 2 files changed, 284 insertions(+), 29 deletions(-) diff --git a/sample_registry/app.py b/sample_registry/app.py index e796b22..ea86492 100644 --- a/sample_registry/app.py +++ b/sample_registry/app.py @@ -1,6 +1,7 @@ -import csv -import pickle -import os +import csv +import gzip +import pickle +import os from collections import defaultdict from datetime import datetime from flask import ( @@ -15,14 +16,18 @@ jsonify, ) from flask_sqlalchemy import SQLAlchemy -from io import StringIO -from pathlib import Path -from sample_registry import ARCHIVE_ROOT, SQLALCHEMY_DATABASE_URI -from sample_registry.models import Base, Annotation, Run, Sample -from sample_registry.db import run_to_dataframe, query_tag_stats, STANDARD_TAGS -from sample_registry.standards import STANDARD_HOST_SPECIES, STANDARD_SAMPLE_TYPES -from typing import Optional -from werkzeug.middleware.proxy_fix import ProxyFix +from contextlib import contextmanager +from io import StringIO +from pathlib import Path +from sample_registry import ARCHIVE_ROOT, SQLALCHEMY_DATABASE_URI +from sample_registry.mapping import SampleTable +from sample_registry.models import Base, Annotation, Run, Sample +from sample_registry.db import run_to_dataframe, query_tag_stats, STANDARD_TAGS +from sample_registry.registrar import SampleRegistry +from sample_registry.standards import STANDARD_HOST_SPECIES, STANDARD_SAMPLE_TYPES +from seqBackupLib.illumina import IlluminaFastq +from typing import Optional +from werkzeug.middleware.proxy_fix import ProxyFix app = Flask(__name__) app.secret_key = os.urandom(12) @@ -31,14 +36,49 @@ # whatever production server you are using instead. It's ok to leave this in when running the dev server. app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) -# Sanitize and RO db connection -SQLALCHEMY_DATABASE_URI = f"{SQLALCHEMY_DATABASE_URI.split('?')[0]}?mode=ro" -app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI -print(SQLALCHEMY_DATABASE_URI) -# Ensure SQLite explicitly opens in read-only mode -app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"connect_args": {"uri": True}} -db = SQLAlchemy(model_class=Base) -db.init_app(app) +# Sanitize and RO db connection +SQLALCHEMY_WRITE_URI = SQLALCHEMY_DATABASE_URI +SQLALCHEMY_DATABASE_URI = f"{SQLALCHEMY_DATABASE_URI.split('?')[0]}?mode=ro" +app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI +print(SQLALCHEMY_DATABASE_URI) +# Ensure SQLite explicitly opens in read-only mode +app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"connect_args": {"uri": True}} +db = SQLAlchemy(model_class=Base) +db.init_app(app) + + +@contextmanager +def api_registry(): + registry = SampleRegistry(uri=SQLALCHEMY_WRITE_URI) + try: + yield registry + finally: + registry.session.close() + + +def api_request_data(): + data = request.get_json(silent=True) + if data is None: + data = request.form.to_dict() + return data or {} + + +def api_error(message: str, status: int = 400): + return jsonify({"status": "error", "error": message}), status + + +def api_sample_table_from_request(): + if "sample_table" in request.files: + content = request.files["sample_table"].stream.read().decode("utf-8") + else: + data = api_request_data() + content = data.get("sample_table") + if not content: + raise ValueError("sample_table is required") + sample_table = SampleTable.load(StringIO(content)) + sample_table.look_up_nextera_barcodes() + sample_table.validate() + return sample_table @app.route("/favicon.ico") @@ -306,8 +346,8 @@ def show_stats(): ) -@app.route("/download/", methods=["GET", "POST"]) -def download(run_acc: str): +@app.route("/download/", methods=["GET", "POST"]) +def download(run_acc: str): ext = run_acc[-4:] run_acc = run_acc[:-4] t = run_to_dataframe(db, run_acc) @@ -344,9 +384,216 @@ def download(run_acc: str): # Create the response and set the appropriate headers response = make_response(csv_file.getvalue()) - response.headers["Content-Disposition"] = f"attachment; filename={run_acc}{ext}" - response.headers["Content-type"] = "text/csv" - return response + response.headers["Content-Disposition"] = f"attachment; filename={run_acc}{ext}" + response.headers["Content-type"] = "text/csv" + return response + + +@app.post("/api/register_run") +def api_register_run(): + data = api_request_data() + missing = [k for k in ("file", "date", "comment") if not data.get(k)] + if missing: + return api_error(f"Missing required fields: {', '.join(missing)}") + try: + lane = int(data.get("lane", 1)) + except ValueError as exc: + return api_error(f"Invalid lane value: {exc}") + with api_registry() as registry: + try: + run_accession = registry.register_run( + data["date"], + data.get("type", "Illumina-MiSeq"), + "Nextera XT", + lane, + data["file"], + data["comment"], + ) + registry.session.commit() + except Exception as exc: + registry.session.rollback() + return api_error(str(exc)) + return jsonify({"status": "ok", "run_accession": run_accession}) + + +@app.post("/api/register_illumina_file") +def api_register_illumina_file(): + data = api_request_data() + missing = [k for k in ("file", "comment") if not data.get(k)] + if missing: + return api_error(f"Missing required fields: {', '.join(missing)}") + try: + f = IlluminaFastq(gzip.open(data["file"], "rt")) + with api_registry() as registry: + run_accession = registry.register_run( + f.folder_info["date"], + f.machine_type, + "Nextera XT", + f.lane, + str(f.filepath), + data["comment"], + ) + registry.session.commit() + except Exception as exc: + return api_error(str(exc)) + return jsonify({"status": "ok", "run_accession": run_accession}) + + +def api_register_sample_annotations(register_samples: bool): + data = api_request_data() + if not data.get("run_accession"): + return api_error("Missing required field: run_accession") + try: + run_accession = int(data["run_accession"]) + except ValueError as exc: + return api_error(f"Invalid run_accession value: {exc}") + try: + sample_table = api_sample_table_from_request() + except Exception as exc: + return api_error(str(exc)) + with api_registry() as registry: + try: + if register_samples: + registry.check_samples(run_accession, exists=False) + registry.check_run_accession(run_accession) + if register_samples: + registry.register_samples(run_accession, sample_table) + registry.register_annotations(run_accession, sample_table) + registry.session.commit() + except Exception as exc: + registry.session.rollback() + return api_error(str(exc)) + return jsonify( + { + "status": "ok", + "run_accession": run_accession, + "sample_count": len(sample_table.recs), + } + ) + + +@app.post("/api/register_samples") +def api_register_samples(): + return api_register_sample_annotations(register_samples=True) + + +@app.post("/api/register_annotations") +def api_register_annotations(): + return api_register_sample_annotations(register_samples=False) + + +@app.post("/api/unregister_samples") +def api_unregister_samples(): + data = api_request_data() + if not data.get("run_accession"): + return api_error("Missing required field: run_accession") + try: + run_accession = int(data["run_accession"]) + except ValueError as exc: + return api_error(f"Invalid run_accession value: {exc}") + with api_registry() as registry: + try: + registry.check_run_accession(run_accession) + samples_removed = registry.remove_samples(run_accession) + registry.session.commit() + except Exception as exc: + registry.session.rollback() + return api_error(str(exc)) + return jsonify( + { + "status": "ok", + "run_accession": run_accession, + "removed_samples": samples_removed, + } + ) + + +@app.post("/api/modify_run") +def api_modify_run(): + data = api_request_data() + if not data.get("run_accession"): + return api_error("Missing required field: run_accession") + try: + run_accession = int(data["run_accession"]) + except ValueError as exc: + return api_error(f"Invalid run_accession value: {exc}") + lane = data.get("lane") + if lane is not None: + try: + lane = int(lane) + except ValueError as exc: + return api_error(f"Invalid lane value: {exc}") + with api_registry() as registry: + try: + registry.check_run_accession(run_accession) + registry.modify_run( + run_accession=run_accession, + run_date=data.get("date"), + machine_type=data.get("type"), + machine_kit=data.get("kit"), + lane=lane, + data_uri=data.get("data_uri"), + comment=data.get("comment"), + admin_comment=data.get("admin_comment"), + ) + registry.session.commit() + except Exception as exc: + registry.session.rollback() + return api_error(str(exc)) + return jsonify({"status": "ok", "run_accession": run_accession}) + + +@app.post("/api/modify_sample") +def api_modify_sample(): + data = api_request_data() + if not data.get("sample_accession"): + return api_error("Missing required field: sample_accession") + try: + sample_accession = int(data["sample_accession"]) + except ValueError as exc: + return api_error(f"Invalid sample_accession value: {exc}") + with api_registry() as registry: + try: + registry.check_sample_accession(sample_accession) + registry.modify_sample( + sample_accession=sample_accession, + sample_name=data.get("sample_name"), + sample_type=data.get("sample_type"), + subject_id=data.get("subject_id"), + host_species=data.get("host_species"), + barcode_sequence=data.get("barcode_sequence"), + primer_sequence=data.get("primer_sequence"), + ) + registry.session.commit() + except Exception as exc: + registry.session.rollback() + return api_error(str(exc)) + return jsonify({"status": "ok", "sample_accession": sample_accession}) + + +@app.post("/api/modify_annotation") +def api_modify_annotation(): + data = api_request_data() + missing = [k for k in ("sample_accession", "key", "val") if not data.get(k)] + if missing: + return api_error(f"Missing required fields: {', '.join(missing)}") + try: + sample_accession = int(data["sample_accession"]) + except ValueError as exc: + return api_error(f"Invalid sample_accession value: {exc}") + with api_registry() as registry: + try: + registry.check_sample_accession(sample_accession) + registry.modify_annotation( + sample_accession=sample_accession, + key=data["key"], + val=data["val"], + ) + registry.session.commit() + except Exception as exc: + registry.session.rollback() + return api_error(str(exc)) + return jsonify({"status": "ok", "sample_accession": sample_accession}) @app.route("/description") diff --git a/sample_registry/registrar.py b/sample_registry/registrar.py index 8775fe4..4bd11b5 100644 --- a/sample_registry/registrar.py +++ b/sample_registry/registrar.py @@ -25,11 +25,19 @@ def __init__(self, session: Optional[Session] = None, uri: Optional[str] = None) self.session = imported_session - def check_run_accession(self, acc: int) -> Run: - run = self.session.scalar(select(Run).where(Run.run_accession == acc)) - if not run: - raise ValueError("Run does not exist %s" % acc) - return run + def check_run_accession(self, acc: int) -> Run: + run = self.session.scalar(select(Run).where(Run.run_accession == acc)) + if not run: + raise ValueError("Run does not exist %s" % acc) + return run + + def check_sample_accession(self, acc: int) -> Sample: + sample = self.session.scalar( + select(Sample).where(Sample.sample_accession == acc) + ) + if not sample: + raise ValueError("Sample does not exist %s" % acc) + return sample def get_run(self, run_accession: int) -> Run | None: """Return the ``Run`` record for ``run_accession``. From 11daf3b3bdd5ca5d268487bb841c89f5b7d17914 Mon Sep 17 00:00:00 2001 From: Charlie Date: Sat, 31 Jan 2026 10:21:39 -0500 Subject: [PATCH 2/4] Refine API registry sessions and routes --- sample_registry/app.py | 63 +++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 41 deletions(-) diff --git a/sample_registry/app.py b/sample_registry/app.py index ea86492..dff21a3 100644 --- a/sample_registry/app.py +++ b/sample_registry/app.py @@ -1,10 +1,9 @@ import csv -import gzip import pickle import os -from collections import defaultdict -from datetime import datetime -from flask import ( +from collections import defaultdict +from datetime import datetime +from flask import ( Flask, make_response, render_template, @@ -25,9 +24,10 @@ from sample_registry.db import run_to_dataframe, query_tag_stats, STANDARD_TAGS from sample_registry.registrar import SampleRegistry from sample_registry.standards import STANDARD_HOST_SPECIES, STANDARD_SAMPLE_TYPES -from seqBackupLib.illumina import IlluminaFastq from typing import Optional from werkzeug.middleware.proxy_fix import ProxyFix +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker app = Flask(__name__) app.secret_key = os.urandom(12) @@ -45,14 +45,18 @@ app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"connect_args": {"uri": True}} db = SQLAlchemy(model_class=Base) db.init_app(app) +write_engine = create_engine(SQLALCHEMY_WRITE_URI, echo=False) +WriteSession = sessionmaker(bind=write_engine) @contextmanager def api_registry(): - registry = SampleRegistry(uri=SQLALCHEMY_WRITE_URI) + session = WriteSession() + registry = SampleRegistry(session=session) try: yield registry finally: + session.close() registry.session.close() @@ -410,32 +414,9 @@ def api_register_run(): data["comment"], ) registry.session.commit() - except Exception as exc: + except Exception: registry.session.rollback() - return api_error(str(exc)) - return jsonify({"status": "ok", "run_accession": run_accession}) - - -@app.post("/api/register_illumina_file") -def api_register_illumina_file(): - data = api_request_data() - missing = [k for k in ("file", "comment") if not data.get(k)] - if missing: - return api_error(f"Missing required fields: {', '.join(missing)}") - try: - f = IlluminaFastq(gzip.open(data["file"], "rt")) - with api_registry() as registry: - run_accession = registry.register_run( - f.folder_info["date"], - f.machine_type, - "Nextera XT", - f.lane, - str(f.filepath), - data["comment"], - ) - registry.session.commit() - except Exception as exc: - return api_error(str(exc)) + raise return jsonify({"status": "ok", "run_accession": run_accession}) @@ -460,9 +441,9 @@ def api_register_sample_annotations(register_samples: bool): registry.register_samples(run_accession, sample_table) registry.register_annotations(run_accession, sample_table) registry.session.commit() - except Exception as exc: + except Exception: registry.session.rollback() - return api_error(str(exc)) + raise return jsonify( { "status": "ok", @@ -496,9 +477,9 @@ def api_unregister_samples(): registry.check_run_accession(run_accession) samples_removed = registry.remove_samples(run_accession) registry.session.commit() - except Exception as exc: + except Exception: registry.session.rollback() - return api_error(str(exc)) + raise return jsonify( { "status": "ok", @@ -537,9 +518,9 @@ def api_modify_run(): admin_comment=data.get("admin_comment"), ) registry.session.commit() - except Exception as exc: + except Exception: registry.session.rollback() - return api_error(str(exc)) + raise return jsonify({"status": "ok", "run_accession": run_accession}) @@ -565,9 +546,9 @@ def api_modify_sample(): primer_sequence=data.get("primer_sequence"), ) registry.session.commit() - except Exception as exc: + except Exception: registry.session.rollback() - return api_error(str(exc)) + raise return jsonify({"status": "ok", "sample_accession": sample_accession}) @@ -590,9 +571,9 @@ def api_modify_annotation(): val=data["val"], ) registry.session.commit() - except Exception as exc: + except Exception: registry.session.rollback() - return api_error(str(exc)) + raise return jsonify({"status": "ok", "sample_accession": sample_accession}) From 4301829532dfb386ce07658dc967ba763cec645f Mon Sep 17 00:00:00 2001 From: Charlie Date: Sat, 31 Jan 2026 10:29:47 -0500 Subject: [PATCH 3/4] Add API endpoint tests --- tests/test_api.py | 264 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 264 insertions(+) create mode 100644 tests/test_api.py diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..0343647 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,264 @@ +import importlib +import io + +import pytest +from sqlalchemy import create_engine, select +from sqlalchemy.orm import sessionmaker + +from sample_registry.db import create_test_db +from sample_registry.mapping import SampleTable +from sample_registry.models import Annotation, Base, Run, Sample + + +SAMPLES = [ + { + "SampleID": "abc123", + "BarcodeSequence": "GGGCCT", + "SampleType": "Oral swab", + "bb": "cd e29", + "ll": "mno 1", + }, + { + "SampleID": "def456", + "BarcodeSequence": "TTTCCC", + "SampleType": "Blood", + "bb": "asdf", + }, +] + +MODIFIED_SAMPLES = [ + { + "SampleID": "abc123", + "BarcodeSequence": "GGGCCT", + "SampleType": "Feces", + "fg": "hi5 34", + } +] + + +def _sample_table_payload(records): + table = SampleTable(records) + buf = io.StringIO() + table.write(buf) + return buf.getvalue() + + +@pytest.fixture +def api_client(tmp_path, monkeypatch): + db_path = tmp_path / "registry.sqlite" + uri = f"sqlite:///{db_path}" + engine = create_engine(uri, echo=False) + Base.metadata.create_all(engine) + Session = sessionmaker(bind=engine) + session = Session() + create_test_db(session) + session.close() + + monkeypatch.setenv("SAMPLE_REGISTRY_DB_URI", uri) + import sample_registry + + importlib.reload(sample_registry) + import sample_registry.app as app_module + + app_module = importlib.reload(app_module) + app_module.app.testing = True + return app_module.app.test_client(), Session + + +def test_http_access_root(api_client): + client, _ = api_client + response = client.get("/") + assert response.status_code == 302 + assert response.headers["Location"].endswith("/runs") + + +def test_api_register_run(api_client): + client, Session = api_client + response = client.post( + "/api/register_run", + json={ + "file": "raw/run4.fastq.gz", + "date": "2024-08-01", + "comment": "new run", + "type": "Illumina-MiSeq", + "lane": 1, + }, + ) + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["run_accession"] == 4 + + session = Session() + try: + run = session.scalar(select(Run).where(Run.run_accession == 4)) + assert run.comment == "new run" + assert run.machine_type == "Illumina-MiSeq" + finally: + session.close() + + +def test_api_register_samples(api_client): + client, Session = api_client + client.post( + "/api/register_run", + json={ + "file": "raw/run4.fastq.gz", + "date": "2024-08-01", + "comment": "new run", + }, + ) + payload = { + "run_accession": 4, + "sample_table": _sample_table_payload(SAMPLES), + } + response = client.post("/api/register_samples", json=payload) + assert response.status_code == 200 + assert response.get_json()["sample_count"] == 2 + + session = Session() + try: + samples = session.scalars( + select(Sample).where(Sample.run_accession == 4).order_by(Sample.sample_accession) + ).all() + assert [s.sample_accession for s in samples] == [6, 7] + assert samples[0].sample_type == "Oral swab" + finally: + session.close() + + +def test_api_register_annotations(api_client): + client, Session = api_client + client.post( + "/api/register_run", + json={ + "file": "raw/run4.fastq.gz", + "date": "2024-08-01", + "comment": "new run", + }, + ) + client.post( + "/api/register_samples", + json={ + "run_accession": 4, + "sample_table": _sample_table_payload(SAMPLES), + }, + ) + response = client.post( + "/api/register_annotations", + json={ + "run_accession": 4, + "sample_table": _sample_table_payload(MODIFIED_SAMPLES), + }, + ) + assert response.status_code == 200 + + session = Session() + try: + sample = session.scalar(select(Sample).where(Sample.sample_accession == 6)) + assert sample.sample_type == "Feces" + annotation = session.scalar( + select(Annotation).where( + Annotation.sample_accession == 6, Annotation.key == "fg" + ) + ) + assert annotation.val == "hi5 34" + finally: + session.close() + + +def test_api_unregister_samples(api_client): + client, Session = api_client + client.post( + "/api/register_run", + json={ + "file": "raw/run4.fastq.gz", + "date": "2024-08-01", + "comment": "new run", + }, + ) + client.post( + "/api/register_samples", + json={ + "run_accession": 4, + "sample_table": _sample_table_payload(SAMPLES), + }, + ) + response = client.post("/api/unregister_samples", json={"run_accession": 4}) + assert response.status_code == 200 + + session = Session() + try: + assert not session.scalar(select(Sample).where(Sample.run_accession == 4)) + finally: + session.close() + + +def test_api_modify_run(api_client): + client, Session = api_client + client.post( + "/api/register_run", + json={ + "file": "raw/run4.fastq.gz", + "date": "2024-08-01", + "comment": "new run", + }, + ) + response = client.post( + "/api/modify_run", + json={ + "run_accession": 4, + "comment": "updated", + "lane": 2, + }, + ) + assert response.status_code == 200 + + session = Session() + try: + run = session.scalar(select(Run).where(Run.run_accession == 4)) + assert run.comment == "updated" + assert run.lane == 2 + finally: + session.close() + + +def test_api_modify_sample(api_client): + client, Session = api_client + response = client.post( + "/api/modify_sample", + json={ + "sample_accession": 1, + "sample_name": "Sample1-updated", + "subject_id": "Subject1a", + }, + ) + assert response.status_code == 200 + + session = Session() + try: + sample = session.scalar(select(Sample).where(Sample.sample_accession == 1)) + assert sample.sample_name == "Sample1-updated" + assert sample.subject_id == "Subject1a" + finally: + session.close() + + +def test_api_modify_annotation(api_client): + client, Session = api_client + response = client.post( + "/api/modify_annotation", + json={"sample_accession": 1, "key": "key0", "val": "updated"}, + ) + assert response.status_code == 200 + + session = Session() + try: + annotation = session.scalar( + select(Annotation).where( + Annotation.sample_accession == 1, Annotation.key == "key0" + ) + ) + assert annotation.val == "updated" + finally: + session.close() From 47bf7a4b09e9bc552bce83117a37f699537a5c95 Mon Sep 17 00:00:00 2001 From: Charlie Date: Sat, 31 Jan 2026 11:05:43 -0500 Subject: [PATCH 4/4] Stabilize API tests for pytest environment --- tests/test_api.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_api.py b/tests/test_api.py index 0343647..a814f5f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,11 +5,9 @@ from sqlalchemy import create_engine, select from sqlalchemy.orm import sessionmaker -from sample_registry.db import create_test_db from sample_registry.mapping import SampleTable from sample_registry.models import Annotation, Base, Run, Sample - SAMPLES = [ { "SampleID": "abc123", @@ -47,16 +45,17 @@ def _sample_table_payload(records): def api_client(tmp_path, monkeypatch): db_path = tmp_path / "registry.sqlite" uri = f"sqlite:///{db_path}" + monkeypatch.setenv("SAMPLE_REGISTRY_DB_URI", uri) + monkeypatch.delenv("PYTEST_VERSION", raising=False) engine = create_engine(uri, echo=False) Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) session = Session() + create_test_db = importlib.import_module("sample_registry.db").create_test_db create_test_db(session) session.close() - monkeypatch.setenv("SAMPLE_REGISTRY_DB_URI", uri) - import sample_registry - + sample_registry = importlib.import_module("sample_registry") importlib.reload(sample_registry) import sample_registry.app as app_module @@ -119,7 +118,9 @@ def test_api_register_samples(api_client): session = Session() try: samples = session.scalars( - select(Sample).where(Sample.run_accession == 4).order_by(Sample.sample_accession) + select(Sample) + .where(Sample.run_accession == 4) + .order_by(Sample.sample_accession) ).all() assert [s.sample_accession for s in samples] == [6, 7] assert samples[0].sample_type == "Oral swab"