Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 255 additions & 27 deletions sample_registry/app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import csv
import pickle
import os
from collections import defaultdict
from datetime import datetime
from flask import (
import csv
import pickle
import os
from collections import defaultdict
from datetime import datetime
from flask import (
Flask,
make_response,
render_template,
Expand All @@ -15,14 +15,19 @@
jsonify,
)
from flask_sqlalchemy import SQLAlchemy
from io import StringIO
from pathlib import Path
from sample_registry import ARCHIVE_ROOT, SQLALCHEMY_DATABASE_URI
from sample_registry.models import Base, Annotation, Run, Sample
from sample_registry.db import run_to_dataframe, query_tag_stats, STANDARD_TAGS
from sample_registry.standards import STANDARD_HOST_SPECIES, STANDARD_SAMPLE_TYPES
from typing import Optional
from werkzeug.middleware.proxy_fix import ProxyFix
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from sample_registry import ARCHIVE_ROOT, SQLALCHEMY_DATABASE_URI
from sample_registry.mapping import SampleTable
from sample_registry.models import Base, Annotation, Run, Sample
from sample_registry.db import run_to_dataframe, query_tag_stats, STANDARD_TAGS
from sample_registry.registrar import SampleRegistry
from sample_registry.standards import STANDARD_HOST_SPECIES, STANDARD_SAMPLE_TYPES
from typing import Optional
from werkzeug.middleware.proxy_fix import ProxyFix
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# Flask application object that every route in this module is registered on.
app = Flask(__name__)
# NOTE(review): the key is drawn from os.urandom each time this module is
# imported, so it differs between processes and restarts — confirm that
# anything signed with it (e.g. session cookies) is not expected to persist.
app.secret_key = os.urandom(12)
Expand All @@ -31,14 +36,53 @@
# whatever production server you are using instead. It's ok to leave this in when running the dev server.
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)

# Sanitize the configured URI and force a read-only DB connection: any
# existing query string is dropped and replaced with "?mode=ro".
SQLALCHEMY_DATABASE_URI = f"{SQLALCHEMY_DATABASE_URI.split('?')[0]}?mode=ro"
app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI
# NOTE(review): this writes the full URI to stdout at import time — confirm it
# can never carry credentials, or switch to logging at debug level.
print(SQLALCHEMY_DATABASE_URI)
# Ensure SQLite explicitly opens in read-only mode: "uri": True is forwarded
# to the driver's connect call so the "?mode=ro" suffix is honoured
# (presumably sqlite3.connect(uri=True) — confirm the backend is SQLite).
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"connect_args": {"uri": True}}
# Flask-SQLAlchemy handle built on the project's declarative Base.
db = SQLAlchemy(model_class=Base)
db.init_app(app)
# Sanitize the configured URI and force a read-only DB connection for the
# Flask-SQLAlchemy handle; the unmodified URI is kept for the write engine.
SQLALCHEMY_WRITE_URI = SQLALCHEMY_DATABASE_URI
# Drop any existing query string and replace it with "?mode=ro".
SQLALCHEMY_DATABASE_URI = f"{SQLALCHEMY_DATABASE_URI.split('?')[0]}?mode=ro"
app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI
# NOTE(review): this writes the full URI to stdout at import time — confirm it
# can never carry credentials, or switch to logging at debug level.
print(SQLALCHEMY_DATABASE_URI)
# Ensure SQLite explicitly opens in read-only mode: "uri": True is forwarded
# to the driver's connect call so the "?mode=ro" suffix is honoured
# (presumably sqlite3.connect(uri=True) — confirm the backend is SQLite).
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"connect_args": {"uri": True}}
# Flask-SQLAlchemy handle built on the project's declarative Base.
db = SQLAlchemy(model_class=Base)
db.init_app(app)
# Separate engine and session factory bound to the original (non read-only)
# URI; api_registry() creates its sessions from WriteSession.
write_engine = create_engine(SQLALCHEMY_WRITE_URI, echo=False)
WriteSession = sessionmaker(bind=write_engine)


@contextmanager
def api_registry():
    """Yield a ``SampleRegistry`` bound to a fresh write session.

    A new session is created from ``WriteSession`` on every use and is always
    closed on exit, whether the ``with`` body finishes normally or raises.
    This helper neither commits nor rolls back; the code inside the ``with``
    block is expected to do that itself.

    Yields:
        The ``SampleRegistry`` built around the new session.
    """
    session = WriteSession()
    registry = None
    try:
        # Build the registry inside the ``try`` so the session is still
        # closed if the constructor raises.
        registry = SampleRegistry(session=session)
        yield registry
    finally:
        session.close()
        # Only close the registry's session separately when it is not the
        # session already closed above (avoids a redundant second close).
        if registry is not None and registry.session is not session:
            registry.session.close()


def api_request_data():
    """Return the request payload as a dict.

    The JSON body is used when it parses; otherwise (``get_json`` returned
    ``None``) the form fields are used instead.

    Returns:
        A dict of request fields. An empty dict is returned when there is no
        payload, or when the JSON body is not an object (for example a list,
        string or number), so callers can always use ``.get`` safely.
    """
    data = request.get_json(silent=True)
    if data is None:
        data = request.form.to_dict()
    # Valid JSON is not necessarily an object; every caller treats the result
    # as a mapping, so anything else is reported as "no fields supplied".
    if not isinstance(data, dict):
        return {}
    return data


def api_error(message: str, status: int = 400):
    """Build a JSON error response paired with an HTTP status code.

    Args:
        message: Text placed under the ``"error"`` key of the JSON body.
        status: HTTP status code returned with the body (default 400).

    Returns:
        A ``(response, status)`` tuple suitable for returning from a view.
    """
    payload = {"status": "error", "error": message}
    return jsonify(payload), status


def api_sample_table_from_request():
    """Load and validate the sample table supplied with the current request.

    The table text is taken from an uploaded file named ``sample_table`` when
    one is present, otherwise from the ``sample_table`` field of the request
    data. It is then parsed with ``SampleTable.load``, Nextera barcodes are
    looked up, and the table is validated.

    Returns:
        The loaded ``SampleTable``.

    Raises:
        ValueError: If no sample table text was supplied, if the supplied
            value is not text, or if an uploaded file is not valid UTF-8
            (``UnicodeDecodeError`` is a ``ValueError`` subclass).
        Exception: Whatever ``SampleTable.load``, ``look_up_nextera_barcodes``
            or ``validate`` raise is propagated unchanged.
    """
    if "sample_table" in request.files:
        raw = request.files["sample_table"].stream.read()
        # "utf-8-sig" decodes plain UTF-8 identically but also strips a
        # leading byte-order mark, which would otherwise end up glued to the
        # first column header.
        content = raw.decode("utf-8-sig")
    else:
        content = api_request_data().get("sample_table")
    if not content:
        raise ValueError("sample_table is required")
    if not isinstance(content, str):
        # A JSON body may carry a list/object here; fail with a clear message
        # instead of an opaque StringIO TypeError.
        raise ValueError("sample_table must be a string")
    sample_table = SampleTable.load(StringIO(content))
    sample_table.look_up_nextera_barcodes()
    sample_table.validate()
    return sample_table


@app.route("/favicon.ico")
Expand Down Expand Up @@ -306,8 +350,8 @@ def show_stats():
)


@app.route("/download/<run_acc>", methods=["GET", "POST"])
def download(run_acc: str):
@app.route("/download/<run_acc>", methods=["GET", "POST"])
def download(run_acc: str):
ext = run_acc[-4:]
run_acc = run_acc[:-4]
t = run_to_dataframe(db, run_acc)
Expand Down Expand Up @@ -344,9 +388,193 @@ def download(run_acc: str):

# Create the response and set the appropriate headers
response = make_response(csv_file.getvalue())
response.headers["Content-Disposition"] = f"attachment; filename={run_acc}{ext}"
response.headers["Content-type"] = "text/csv"
return response
response.headers["Content-Disposition"] = f"attachment; filename={run_acc}{ext}"
response.headers["Content-type"] = "text/csv"
return response


@app.post("/api/register_run")
def api_register_run():
    """Register a new run and return its accession as JSON.

    Request fields: ``file``, ``date`` and ``comment`` are required; ``lane``
    defaults to 1 and ``type`` defaults to ``"Illumina-MiSeq"``. The literal
    ``"Nextera XT"`` is always passed as the third argument to
    ``registry.register_run``.

    Returns:
        ``{"status": "ok", "run_accession": ...}`` on success, or a 400 JSON
        error when a required field is missing or ``lane`` is not an integer.
        Exceptions raised while registering are re-raised after a rollback.
    """
    data = api_request_data()
    missing = [k for k in ("file", "date", "comment") if not data.get(k)]
    if missing:
        return api_error(f"Missing required fields: {', '.join(missing)}")
    try:
        lane = int(data.get("lane", 1))
    except (TypeError, ValueError) as exc:
        # int() raises TypeError (not ValueError) for JSON null, lists and
        # objects; report those as a bad request too.
        return api_error(f"Invalid lane value: {exc}")
    with api_registry() as registry:
        try:
            run_accession = registry.register_run(
                data["date"],
                data.get("type", "Illumina-MiSeq"),
                "Nextera XT",
                lane,
                data["file"],
                data["comment"],
            )
            registry.session.commit()
        except Exception:
            # Undo any partial write before the error propagates.
            registry.session.rollback()
            raise
    return jsonify({"status": "ok", "run_accession": run_accession})


def api_register_sample_annotations(register_samples: bool):
    """Register annotations (and optionally samples) for a run from a request.

    Args:
        register_samples: When true, first check via
            ``registry.check_samples(run_accession, exists=False)`` and then
            register the samples from the table before the annotations. When
            false, only the annotations are registered.

    Returns:
        ``{"status": "ok", "run_accession": ..., "sample_count": ...}`` on
        success, where ``sample_count`` is ``len(sample_table.recs)``. A 400
        JSON error is returned when ``run_accession`` is missing or not an
        integer, or when the sample table cannot be loaded. Exceptions raised
        by the registry are re-raised after a rollback.
    """
    data = api_request_data()
    if not data.get("run_accession"):
        return api_error("Missing required field: run_accession")
    try:
        run_accession = int(data["run_accession"])
    except (TypeError, ValueError) as exc:
        # int() raises TypeError (not ValueError) for lists and objects;
        # report those as a bad request too.
        return api_error(f"Invalid run_accession value: {exc}")
    try:
        sample_table = api_sample_table_from_request()
    except Exception as exc:
        # Any load/validation failure is reported back to the client as-is.
        return api_error(str(exc))
    with api_registry() as registry:
        try:
            if register_samples:
                registry.check_samples(run_accession, exists=False)
            registry.check_run_accession(run_accession)
            if register_samples:
                registry.register_samples(run_accession, sample_table)
            registry.register_annotations(run_accession, sample_table)
            registry.session.commit()
        except Exception:
            # Undo any partial write before the error propagates.
            registry.session.rollback()
            raise
    return jsonify(
        {
            "status": "ok",
            "run_accession": run_accession,
            "sample_count": len(sample_table.recs),
        }
    )


@app.post("/api/register_samples")
def api_register_samples():
    """Handle POST /api/register_samples: register samples plus annotations."""
    response = api_register_sample_annotations(register_samples=True)
    return response


@app.post("/api/register_annotations")
def api_register_annotations():
    """Handle POST /api/register_annotations: register annotations only."""
    response = api_register_sample_annotations(register_samples=False)
    return response


@app.post("/api/unregister_samples")
def api_unregister_samples():
    """Remove the samples of a run and report what was removed.

    Request fields: ``run_accession`` (required, integer).

    Returns:
        ``{"status": "ok", "run_accession": ..., "removed_samples": ...}`` on
        success, where ``removed_samples`` is whatever
        ``registry.remove_samples`` returned. A 400 JSON error is returned
        when ``run_accession`` is missing or not an integer. Exceptions
        raised by the registry are re-raised after a rollback.
    """
    data = api_request_data()
    if not data.get("run_accession"):
        return api_error("Missing required field: run_accession")
    try:
        run_accession = int(data["run_accession"])
    except (TypeError, ValueError) as exc:
        # int() raises TypeError (not ValueError) for lists and objects;
        # report those as a bad request too.
        return api_error(f"Invalid run_accession value: {exc}")
    with api_registry() as registry:
        try:
            registry.check_run_accession(run_accession)
            samples_removed = registry.remove_samples(run_accession)
            registry.session.commit()
        except Exception:
            # Undo any partial write before the error propagates.
            registry.session.rollback()
            raise
    return jsonify(
        {
            "status": "ok",
            "run_accession": run_accession,
            "removed_samples": samples_removed,
        }
    )


@app.post("/api/modify_run")
def api_modify_run():
    """Modify fields of an existing run.

    Request fields: ``run_accession`` (required, integer) and the optional
    ``date``, ``type``, ``kit``, ``lane`` (integer), ``data_uri``,
    ``comment`` and ``admin_comment``. Fields that are absent are passed to
    ``registry.modify_run`` as ``None``.

    Returns:
        ``{"status": "ok", "run_accession": ...}`` on success, or a 400 JSON
        error when ``run_accession`` is missing or when ``run_accession`` or
        ``lane`` is not an integer. Exceptions raised by the registry are
        re-raised after a rollback.
    """
    data = api_request_data()
    if not data.get("run_accession"):
        return api_error("Missing required field: run_accession")
    try:
        run_accession = int(data["run_accession"])
    except (TypeError, ValueError) as exc:
        # int() raises TypeError (not ValueError) for lists and objects;
        # report those as a bad request too.
        return api_error(f"Invalid run_accession value: {exc}")
    lane = data.get("lane")
    if lane is not None:
        try:
            lane = int(lane)
        except (TypeError, ValueError) as exc:
            return api_error(f"Invalid lane value: {exc}")
    with api_registry() as registry:
        try:
            registry.check_run_accession(run_accession)
            registry.modify_run(
                run_accession=run_accession,
                run_date=data.get("date"),
                machine_type=data.get("type"),
                machine_kit=data.get("kit"),
                lane=lane,
                data_uri=data.get("data_uri"),
                comment=data.get("comment"),
                admin_comment=data.get("admin_comment"),
            )
            registry.session.commit()
        except Exception:
            # Undo any partial write before the error propagates.
            registry.session.rollback()
            raise
    return jsonify({"status": "ok", "run_accession": run_accession})


@app.post("/api/modify_sample")
def api_modify_sample():
    """Modify fields of an existing sample.

    Request fields: ``sample_accession`` (required, integer) and the optional
    ``sample_name``, ``sample_type``, ``subject_id``, ``host_species``,
    ``barcode_sequence`` and ``primer_sequence``. Fields that are absent are
    passed to ``registry.modify_sample`` as ``None``.

    Returns:
        ``{"status": "ok", "sample_accession": ...}`` on success, or a 400
        JSON error when ``sample_accession`` is missing or not an integer.
        Exceptions raised by the registry are re-raised after a rollback.
    """
    data = api_request_data()
    if not data.get("sample_accession"):
        return api_error("Missing required field: sample_accession")
    try:
        sample_accession = int(data["sample_accession"])
    except (TypeError, ValueError) as exc:
        # int() raises TypeError (not ValueError) for lists and objects;
        # report those as a bad request too.
        return api_error(f"Invalid sample_accession value: {exc}")
    with api_registry() as registry:
        try:
            registry.check_sample_accession(sample_accession)
            registry.modify_sample(
                sample_accession=sample_accession,
                sample_name=data.get("sample_name"),
                sample_type=data.get("sample_type"),
                subject_id=data.get("subject_id"),
                host_species=data.get("host_species"),
                barcode_sequence=data.get("barcode_sequence"),
                primer_sequence=data.get("primer_sequence"),
            )
            registry.session.commit()
        except Exception:
            # Undo any partial write before the error propagates.
            registry.session.rollback()
            raise
    return jsonify({"status": "ok", "sample_accession": sample_accession})


@app.post("/api/modify_annotation")
def api_modify_annotation():
    """Modify one annotation of an existing sample.

    Request fields: ``sample_accession`` (integer), ``key`` and ``val`` are
    all required; a field counts as missing when its value is falsy.

    Returns:
        ``{"status": "ok", "sample_accession": ...}`` on success, or a 400
        JSON error when a required field is missing or ``sample_accession``
        is not an integer. Exceptions raised by the registry are re-raised
        after a rollback.
    """
    data = api_request_data()
    missing = [k for k in ("sample_accession", "key", "val") if not data.get(k)]
    if missing:
        return api_error(f"Missing required fields: {', '.join(missing)}")
    try:
        sample_accession = int(data["sample_accession"])
    except (TypeError, ValueError) as exc:
        # int() raises TypeError (not ValueError) for lists and objects;
        # report those as a bad request too.
        return api_error(f"Invalid sample_accession value: {exc}")
    with api_registry() as registry:
        try:
            registry.check_sample_accession(sample_accession)
            registry.modify_annotation(
                sample_accession=sample_accession,
                key=data["key"],
                val=data["val"],
            )
            registry.session.commit()
        except Exception:
            # Undo any partial write before the error propagates.
            registry.session.rollback()
            raise
    return jsonify({"status": "ok", "sample_accession": sample_accession})


@app.route("/description")
Expand Down
18 changes: 13 additions & 5 deletions sample_registry/registrar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,19 @@ def __init__(self, session: Optional[Session] = None, uri: Optional[str] = None)

self.session = imported_session

def check_run_accession(self, acc: int) -> Run:
    """Return the ``Run`` whose accession is ``acc``.

    Raises:
        ValueError: If the query yields no run for ``acc``.
    """
    query = select(Run).where(Run.run_accession == acc)
    found = self.session.scalar(query)
    if found:
        return found
    raise ValueError("Run does not exist %s" % acc)
def check_run_accession(self, acc: int) -> Run:
    """Return the ``Run`` whose accession is ``acc``.

    Raises:
        ValueError: If the query yields no run for ``acc``.
    """
    query = select(Run).where(Run.run_accession == acc)
    found = self.session.scalar(query)
    if found:
        return found
    raise ValueError("Run does not exist %s" % acc)

def check_sample_accession(self, acc: int) -> Sample:
    """Return the ``Sample`` whose accession is ``acc``.

    Raises:
        ValueError: If the query yields no sample for ``acc``.
    """
    query = select(Sample).where(Sample.sample_accession == acc)
    found = self.session.scalar(query)
    if found:
        return found
    raise ValueError("Sample does not exist %s" % acc)

def get_run(self, run_accession: int) -> Run | None:
"""Return the ``Run`` record for ``run_accession``.
Expand Down
Loading