diff --git a/sample_registry/app.py b/sample_registry/app.py index e796b22..dff21a3 100644 --- a/sample_registry/app.py +++ b/sample_registry/app.py @@ -1,9 +1,9 @@ -import csv -import pickle -import os -from collections import defaultdict -from datetime import datetime -from flask import ( +import csv +import pickle +import os +from collections import defaultdict +from datetime import datetime +from flask import ( Flask, make_response, render_template, @@ -15,14 +15,19 @@ jsonify, ) from flask_sqlalchemy import SQLAlchemy -from io import StringIO -from pathlib import Path -from sample_registry import ARCHIVE_ROOT, SQLALCHEMY_DATABASE_URI -from sample_registry.models import Base, Annotation, Run, Sample -from sample_registry.db import run_to_dataframe, query_tag_stats, STANDARD_TAGS -from sample_registry.standards import STANDARD_HOST_SPECIES, STANDARD_SAMPLE_TYPES -from typing import Optional -from werkzeug.middleware.proxy_fix import ProxyFix +from contextlib import contextmanager +from io import StringIO +from pathlib import Path +from sample_registry import ARCHIVE_ROOT, SQLALCHEMY_DATABASE_URI +from sample_registry.mapping import SampleTable +from sample_registry.models import Base, Annotation, Run, Sample +from sample_registry.db import run_to_dataframe, query_tag_stats, STANDARD_TAGS +from sample_registry.registrar import SampleRegistry +from sample_registry.standards import STANDARD_HOST_SPECIES, STANDARD_SAMPLE_TYPES +from typing import Optional +from werkzeug.middleware.proxy_fix import ProxyFix +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker app = Flask(__name__) app.secret_key = os.urandom(12) @@ -31,14 +36,53 @@ # whatever production server you are using instead. It's ok to leave this in when running the dev server. 
@contextmanager
def api_registry():
    """Yield a ``SampleRegistry`` bound to a fresh write session.

    A new session is created from ``WriteSession`` for every use and is
    always closed when the ``with`` block exits, including when building the
    registry itself fails. Committing or rolling back is the caller's job.
    """
    session = WriteSession()
    registry = None
    try:
        # Build the registry inside the ``try`` so the session cannot leak
        # if the constructor raises.
        registry = SampleRegistry(session=session)
        yield registry
    finally:
        session.close()
        # NOTE(review): the registry is expected to reuse the session passed
        # in; only close its session separately if it is a different object.
        registry_session = getattr(registry, "session", None)
        if registry_session is not None and registry_session is not session:
            registry_session.close()


def api_request_data():
    """Return the request parameters as a dict.

    A JSON object body is used when present. For anything else (no JSON
    body, or a JSON value that is not an object, such as a list) the form
    fields are used instead, so callers can always rely on ``dict.get``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = request.form.to_dict()
    return data or {}


def api_error(message: str, status: int = 400):
    """Build a JSON error response.

    Returns a ``(response, status)`` tuple whose JSON body is
    ``{"status": "error", "error": message}``.
    """
    return jsonify({"status": "error", "error": message}), status


def api_sample_table_from_request():
    """Load and validate the sample table sent with the current request.

    The table text is taken from an uploaded file named ``sample_table`` if
    there is one, otherwise from the ``sample_table`` request field. It is
    parsed with ``SampleTable.load``, then ``look_up_nextera_barcodes()`` and
    ``validate()`` are called on the result.

    Raises:
        ValueError: if no table was supplied or the field is not text.
        Errors raised while decoding, loading or validating the table
        propagate to the caller unchanged.
    """
    if "sample_table" in request.files:
        raw = request.files["sample_table"].stream.read()
        # "utf-8-sig" decodes plain UTF-8 unchanged and drops a leading
        # byte-order mark, which would otherwise end up in the first header.
        content = raw.decode("utf-8-sig")
    else:
        content = api_request_data().get("sample_table")
    if not content:
        raise ValueError("sample_table is required")
    if not isinstance(content, str):
        raise ValueError("sample_table must be text")
    sample_table = SampleTable.load(StringIO(content))
    sample_table.look_up_nextera_barcodes()
    sample_table.validate()
    return sample_table
def _api_parse_int(field, raw):
    """Convert ``raw`` to ``int`` or raise ``ValueError`` naming ``field``.

    ``TypeError`` is handled as well, so a JSON ``null``, list or object is
    reported as an invalid value instead of crashing the request.
    """
    try:
        return int(raw)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"Invalid {field} value: {exc}") from exc


def _api_required_int(data, field):
    """Return the required integer ``field`` of ``data`` or raise ``ValueError``."""
    if not data.get(field):
        raise ValueError(f"Missing required field: {field}")
    return _api_parse_int(field, data[field])


def _api_run_in_transaction(operation):
    """Run ``operation(registry)`` in a write session and commit it.

    Returns ``(result, None)`` on success. A ``ValueError`` raised by the
    operation (the registry's ``check_*`` methods raise it for unknown
    accessions) is rolled back and returned as ``(None, error_response)``.
    Any other exception is rolled back and re-raised.
    """
    with api_registry() as registry:
        try:
            result = operation(registry)
            registry.session.commit()
        except ValueError as exc:
            registry.session.rollback()
            return None, api_error(str(exc))
        except Exception:
            registry.session.rollback()
            raise
    return result, None


@app.post("/api/register_run")
def api_register_run():
    """Register a new run.

    Required fields: ``file``, ``date``, ``comment``. Optional: ``type``
    (default ``"Illumina-MiSeq"``) and ``lane`` (default ``1``). Responds
    with the new ``run_accession``.
    """
    data = api_request_data()
    missing = [k for k in ("file", "date", "comment") if not data.get(k)]
    if missing:
        return api_error(f"Missing required fields: {', '.join(missing)}")
    try:
        lane = _api_parse_int("lane", data.get("lane", 1))
    except ValueError as exc:
        return api_error(str(exc))

    def register(registry):
        return registry.register_run(
            data["date"],
            data.get("type", "Illumina-MiSeq"),
            "Nextera XT",
            lane,
            data["file"],
            data["comment"],
        )

    run_accession, error = _api_run_in_transaction(register)
    if error is not None:
        return error
    return jsonify({"status": "ok", "run_accession": run_accession})


def api_register_sample_annotations(register_samples: bool):
    """Register annotations, and optionally the samples, for a run.

    Shared implementation of ``/api/register_samples`` and
    ``/api/register_annotations``. Requires ``run_accession`` and a sample
    table (see ``api_sample_table_from_request``). When ``register_samples``
    is true the samples are registered before their annotations.
    """
    data = api_request_data()
    try:
        run_accession = _api_required_int(data, "run_accession")
    except ValueError as exc:
        return api_error(str(exc))
    try:
        sample_table = api_sample_table_from_request()
    except Exception as exc:
        # Any problem with the submitted table is the client's to fix, so it
        # is reported as a JSON error rather than a server failure.
        return api_error(str(exc))

    def register(registry):
        if register_samples:
            registry.check_samples(run_accession, exists=False)
        registry.check_run_accession(run_accession)
        if register_samples:
            registry.register_samples(run_accession, sample_table)
        registry.register_annotations(run_accession, sample_table)

    _, error = _api_run_in_transaction(register)
    if error is not None:
        return error
    return jsonify(
        {
            "status": "ok",
            "run_accession": run_accession,
            "sample_count": len(sample_table.recs),
        }
    )


@app.post("/api/register_samples")
def api_register_samples():
    """Register the samples of a run together with their annotations."""
    return api_register_sample_annotations(register_samples=True)


@app.post("/api/register_annotations")
def api_register_annotations():
    """Register annotations for a run without registering its samples."""
    return api_register_sample_annotations(register_samples=False)


@app.post("/api/unregister_samples")
def api_unregister_samples():
    """Remove the samples of the run given by ``run_accession``.

    Responds with whatever ``remove_samples`` returned under
    ``removed_samples``.
    """
    data = api_request_data()
    try:
        run_accession = _api_required_int(data, "run_accession")
    except ValueError as exc:
        return api_error(str(exc))

    def unregister(registry):
        registry.check_run_accession(run_accession)
        return registry.remove_samples(run_accession)

    samples_removed, error = _api_run_in_transaction(unregister)
    if error is not None:
        return error
    return jsonify(
        {
            "status": "ok",
            "run_accession": run_accession,
            "removed_samples": samples_removed,
        }
    )


@app.post("/api/modify_run")
def api_modify_run():
    """Update fields of an existing run.

    Requires ``run_accession``. The optional fields ``date``, ``type``,
    ``kit``, ``lane``, ``data_uri``, ``comment`` and ``admin_comment`` are
    passed on to ``modify_run``; absent ones are passed as ``None``.
    """
    data = api_request_data()
    try:
        run_accession = _api_required_int(data, "run_accession")
        lane = data.get("lane")
        if lane is not None:
            lane = _api_parse_int("lane", lane)
    except ValueError as exc:
        return api_error(str(exc))

    def modify(registry):
        registry.check_run_accession(run_accession)
        registry.modify_run(
            run_accession=run_accession,
            run_date=data.get("date"),
            machine_type=data.get("type"),
            machine_kit=data.get("kit"),
            lane=lane,
            data_uri=data.get("data_uri"),
            comment=data.get("comment"),
            admin_comment=data.get("admin_comment"),
        )

    _, error = _api_run_in_transaction(modify)
    if error is not None:
        return error
    return jsonify({"status": "ok", "run_accession": run_accession})


@app.post("/api/modify_sample")
def api_modify_sample():
    """Update fields of an existing sample.

    Requires ``sample_accession``. The optional fields ``sample_name``,
    ``sample_type``, ``subject_id``, ``host_species``, ``barcode_sequence``
    and ``primer_sequence`` are passed on to ``modify_sample``.
    """
    data = api_request_data()
    try:
        sample_accession = _api_required_int(data, "sample_accession")
    except ValueError as exc:
        return api_error(str(exc))

    def modify(registry):
        registry.check_sample_accession(sample_accession)
        registry.modify_sample(
            sample_accession=sample_accession,
            sample_name=data.get("sample_name"),
            sample_type=data.get("sample_type"),
            subject_id=data.get("subject_id"),
            host_species=data.get("host_species"),
            barcode_sequence=data.get("barcode_sequence"),
            primer_sequence=data.get("primer_sequence"),
        )

    _, error = _api_run_in_transaction(modify)
    if error is not None:
        return error
    return jsonify({"status": "ok", "sample_accession": sample_accession})


@app.post("/api/modify_annotation")
def api_modify_annotation():
    """Set one annotation of an existing sample.

    Requires ``sample_accession``, ``key`` and ``val``.
    """
    data = api_request_data()
    missing = [k for k in ("sample_accession", "key", "val") if not data.get(k)]
    if missing:
        return api_error(f"Missing required fields: {', '.join(missing)}")
    try:
        sample_accession = _api_parse_int(
            "sample_accession", data["sample_accession"]
        )
    except ValueError as exc:
        return api_error(str(exc))

    def modify(registry):
        registry.check_sample_accession(sample_accession)
        registry.modify_annotation(
            sample_accession=sample_accession,
            key=data["key"],
            val=data["val"],
        )

    _, error = _api_run_in_transaction(modify)
    if error is not None:
        return error
    return jsonify({"status": "ok", "sample_accession": sample_accession})
def check_run_accession(self, acc: int) -> Run:
    """Return the ``Run`` whose accession is ``acc``.

    Raises:
        ValueError: if the query finds no such run.
    """
    query = select(Run).where(Run.run_accession == acc)
    found = self.session.scalar(query)
    if found:
        return found
    raise ValueError("Run does not exist %s" % acc)

def check_sample_accession(self, acc: int) -> Sample:
    """Return the ``Sample`` whose accession is ``acc``.

    Raises:
        ValueError: if the query finds no such sample.
    """
    query = select(Sample).where(Sample.sample_accession == acc)
    found = self.session.scalar(query)
    if found:
        return found
    raise ValueError("Sample does not exist %s" % acc)
"""HTTP-level tests for the ``/api/*`` write endpoints of ``sample_registry.app``."""

import importlib
import io
from contextlib import closing

import pytest
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker

from sample_registry.mapping import SampleTable
from sample_registry.models import Annotation, Base, Run, Sample

SAMPLES = [
    {
        "SampleID": "abc123",
        "BarcodeSequence": "GGGCCT",
        "SampleType": "Oral swab",
        "bb": "cd e29",
        "ll": "mno 1",
    },
    {
        "SampleID": "def456",
        "BarcodeSequence": "TTTCCC",
        "SampleType": "Blood",
        "bb": "asdf",
    },
]

MODIFIED_SAMPLES = [
    {
        "SampleID": "abc123",
        "BarcodeSequence": "GGGCCT",
        "SampleType": "Feces",
        "fg": "hi5 34",
    }
]

# Minimal payload accepted by /api/register_run.
RUN_PAYLOAD = {
    "file": "raw/run4.fastq.gz",
    "date": "2024-08-01",
    "comment": "new run",
}

# Accession these tests expect for the first run registered on top of the
# database built by ``create_test_db``.
NEW_RUN_ACCESSION = 4


def _sample_table_payload(records):
    """Serialize ``records`` to the text form the API accepts as ``sample_table``."""
    table = SampleTable(records)
    buf = io.StringIO()
    table.write(buf)
    return buf.getvalue()


def _post_ok(client, url, payload):
    """POST ``payload`` as JSON and fail right away unless the status is 200."""
    response = client.post(url, json=payload)
    assert response.status_code == 200, response.get_data(as_text=True)
    return response


def _register_run(client, **extra):
    """Register the standard test run; ``extra`` adds or overrides fields."""
    return _post_ok(client, "/api/register_run", {**RUN_PAYLOAD, **extra})


def _register_samples(client, records):
    """Register ``records`` as the samples of the new test run."""
    return _post_ok(
        client,
        "/api/register_samples",
        {
            "run_accession": NEW_RUN_ACCESSION,
            "sample_table": _sample_table_payload(records),
        },
    )


@pytest.fixture
def api_client(tmp_path, monkeypatch):
    """Yield ``(test_client, Session)`` backed by a fresh SQLite test database."""
    db_path = tmp_path / "registry.sqlite"
    uri = f"sqlite:///{db_path}"
    monkeypatch.setenv("SAMPLE_REGISTRY_DB_URI", uri)
    monkeypatch.delenv("PYTEST_VERSION", raising=False)
    engine = create_engine(uri, echo=False)
    Base.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    create_test_db = importlib.import_module("sample_registry.db").create_test_db
    create_test_db(session)
    session.close()

    # Reload the package and the app so they pick up the environment set above.
    sample_registry = importlib.import_module("sample_registry")
    importlib.reload(sample_registry)
    import sample_registry.app as app_module

    app_module = importlib.reload(app_module)
    app_module.app.testing = True
    yield app_module.app.test_client(), Session
    # Release the verification engine's connections once the test is done.
    engine.dispose()


def test_http_access_root(api_client):
    client, _ = api_client
    response = client.get("/")
    assert response.status_code == 302
    assert response.headers["Location"].endswith("/runs")


def test_api_register_run(api_client):
    client, Session = api_client
    response = _register_run(client, type="Illumina-MiSeq", lane=1)
    payload = response.get_json()
    assert payload["status"] == "ok"
    assert payload["run_accession"] == NEW_RUN_ACCESSION

    with closing(Session()) as session:
        run = session.scalar(
            select(Run).where(Run.run_accession == NEW_RUN_ACCESSION)
        )
        assert run.comment == "new run"
        assert run.machine_type == "Illumina-MiSeq"


def test_api_register_samples(api_client):
    client, Session = api_client
    _register_run(client)
    response = _register_samples(client, SAMPLES)
    assert response.get_json()["sample_count"] == 2

    with closing(Session()) as session:
        samples = session.scalars(
            select(Sample)
            .where(Sample.run_accession == NEW_RUN_ACCESSION)
            .order_by(Sample.sample_accession)
        ).all()
        assert [s.sample_accession for s in samples] == [6, 7]
        assert samples[0].sample_type == "Oral swab"


def test_api_register_annotations(api_client):
    client, Session = api_client
    _register_run(client)
    _register_samples(client, SAMPLES)
    _post_ok(
        client,
        "/api/register_annotations",
        {
            "run_accession": NEW_RUN_ACCESSION,
            "sample_table": _sample_table_payload(MODIFIED_SAMPLES),
        },
    )

    with closing(Session()) as session:
        sample = session.scalar(select(Sample).where(Sample.sample_accession == 6))
        assert sample.sample_type == "Feces"
        annotation = session.scalar(
            select(Annotation).where(
                Annotation.sample_accession == 6, Annotation.key == "fg"
            )
        )
        assert annotation.val == "hi5 34"


def test_api_unregister_samples(api_client):
    client, Session = api_client
    _register_run(client)
    _register_samples(client, SAMPLES)
    _post_ok(
        client, "/api/unregister_samples", {"run_accession": NEW_RUN_ACCESSION}
    )

    with closing(Session()) as session:
        assert not session.scalar(
            select(Sample).where(Sample.run_accession == NEW_RUN_ACCESSION)
        )


def test_api_modify_run(api_client):
    client, Session = api_client
    _register_run(client)
    _post_ok(
        client,
        "/api/modify_run",
        {"run_accession": NEW_RUN_ACCESSION, "comment": "updated", "lane": 2},
    )

    with closing(Session()) as session:
        run = session.scalar(
            select(Run).where(Run.run_accession == NEW_RUN_ACCESSION)
        )
        assert run.comment == "updated"
        assert run.lane == 2


def test_api_modify_sample(api_client):
    client, Session = api_client
    _post_ok(
        client,
        "/api/modify_sample",
        {
            "sample_accession": 1,
            "sample_name": "Sample1-updated",
            "subject_id": "Subject1a",
        },
    )

    with closing(Session()) as session:
        sample = session.scalar(select(Sample).where(Sample.sample_accession == 1))
        assert sample.sample_name == "Sample1-updated"
        assert sample.subject_id == "Subject1a"


def test_api_modify_annotation(api_client):
    client, Session = api_client
    _post_ok(
        client,
        "/api/modify_annotation",
        {"sample_accession": 1, "key": "key0", "val": "updated"},
    )

    with closing(Session()) as session:
        annotation = session.scalar(
            select(Annotation).where(
                Annotation.sample_accession == 1, Annotation.key == "key0"
            )
        )
        assert annotation.val == "updated"