diff --git a/qmra/risk_assessment/templates/assessment-configurator.html b/qmra/risk_assessment/templates/assessment-configurator.html index 638cbcf..5c9ccda 100644 --- a/qmra/risk_assessment/templates/assessment-configurator.html +++ b/qmra/risk_assessment/templates/assessment-configurator.html @@ -130,6 +130,37 @@

Treatments

{% crispy user_source_form %} +
+
Upload pathogen measurements (CSV)
+
Each file must contain one column with the exact pathogen name.
+
+ +
+ + +
+
+
+
+
+ +
+ + +
+
+
+
+
+ +
+ + +
+
+
+
+
{% crispy user_treatment_form %} @@ -179,6 +210,94 @@

Treatments

/**
 * Display name of each pathogen whose form-field prefix differs from the
 * fallback. Field ids in the source form are built as `#id_<prefix>_min`
 * and `#id_<prefix>_max`.
 */
const PATHOGEN_FIELD_PREFIXES = new Map([
    ["Rotavirus", "rotavirus"],
    ["Campylobacter jejuni", "campylobacter"],
]);

/**
 * Translate a pathogen display name into the prefix of its min/max inputs.
 *
 * @param {string} pathogen - Pathogen name as carried by the data attributes.
 * @returns {string} The field prefix; any name not listed above falls back
 *   to "cryptosporidium".
 */
function pathogenToFieldPrefix(pathogen) {
    return PATHOGEN_FIELD_PREFIXES.get(pathogen) ?? "cryptosporidium";
}

/**
 * Build the DOM id of the histogram container for a pathogen.
 *
 * @param {string} pathogen - Pathogen display name.
 * @returns {string} "source-fit-histogram-" followed by the name with every
 *   space replaced by a dash.
 */
function pathogenToHistogramId(pathogen) {
    return `source-fit-histogram-${pathogen.replaceAll(" ", "-")}`;
}

/**
 * Show a status line next to a pathogen's upload controls.
 *
 * Does nothing when no element carries a matching `data-pathogen-message`
 * attribute.
 *
 * @param {string} pathogen - Pathogen display name used to find the node.
 * @param {string} message - Text to display (assigned via textContent).
 * @param {boolean} [isError=false] - true styles the node with "text-danger",
 *   false with "text-success".
 */
function setPathogenMessage(pathogen, message, isError = false) {
    const msgNode = document.querySelector(`[data-pathogen-message="${pathogen}"]`);
    if (!msgNode) { return; }
    msgNode.classList.toggle("text-danger", isError);
    msgNode.classList.toggle("text-success", !isError);
    msgNode.textContent = message;
}

/**
 * Draw (or redraw) the simulated-count bar chart for a pathogen.
 *
 * Skips drawing when the container is absent or when the server response
 * carried no histogram, so a missing chart never turns an otherwise
 * successful calculation into an error.
 *
 * @param {string} pathogen - Pathogen display name.
 * @param {{x: number[], y: number[]}|undefined|null} histogram - Bar
 *   positions and heights as returned by the fit endpoint.
 */
function renderSourceHistogram(pathogen, histogram) {
    const targetId = pathogenToHistogramId(pathogen);
    const target = document.getElementById(targetId);
    if (!target || !histogram) { return; }
    Plotly.react(targetId, [{
        x: histogram.x,
        y: histogram.y,
        type: "bar",
        marker: { color: "#4f5dff" },
        name: "Simulated count",
    }], {
        margin: { l: 35, r: 10, t: 20, b: 35 },
        height: 180,
        paper_bgcolor: "#ffffff",
        plot_bgcolor: "#ffffff",
        xaxis: { title: "Concentration" },
        yaxis: { title: "Frequency" },
    }, { displaylogo: false, responsive: true });
}

/**
 * Parse a response body as JSON, yielding null when it is not JSON.
 *
 * Error pages and login redirects return HTML; the caller decides how to
 * report that instead of failing inside the parser.
 *
 * @param {Response} response - The fetch response to read.
 * @returns {Promise<object|null>} Parsed body, or null if parsing failed.
 */
async function readJsonOrNull(response) {
    try {
        return await response.json();
    } catch (err) {
        return null;
    }
}

/**
 * Wire every `[data-pathogen-calc]` button inside `#create-source` to upload
 * the matching CSV file, fill the pathogen's min/max inputs from the fitted
 * quantiles (`q025` / `q975`) and render the returned histogram.
 *
 * Does nothing when `#user-source-form` is not on the page.
 *
 * @returns {Promise<void>}
 */
async function bindSourceFitHandlers() {
    const sourceForm = document.querySelector("#user-source-form");
    if (!sourceForm) { return; }
    const csrfToken = document.querySelector('[name=csrfmiddlewaretoken]')?.value;

    document.querySelectorAll("#create-source [data-pathogen-calc]").forEach((button) => {
        button.addEventListener("click", async () => {
            const pathogen = button.getAttribute("data-pathogen-calc");
            const fileInput = document.querySelector(`#create-source [data-pathogen-file="${pathogen}"]`);
            const file = fileInput?.files?.[0];
            if (!file) {
                setPathogenMessage(pathogen, "Please select a CSV file first.", true);
                return;
            }

            const payload = new FormData();
            payload.append("pathogen", pathogen);
            payload.append("file", file);
            setPathogenMessage(pathogen, "Calculating distribution...");

            // Block repeated clicks while a request is in flight: overlapping
            // requests could complete out of order and overwrite min/max.
            button.disabled = true;
            try {
                const response = await fetch("{% url 'source-inflow-fit' %}", {
                    method: "POST",
                    // Omit the header rather than sending the string "undefined".
                    headers: csrfToken ? { "X-CSRFToken": csrfToken } : {},
                    body: payload,
                });
                const result = await readJsonOrNull(response);
                if (!response.ok) {
                    setPathogenMessage(pathogen, result?.error || "Calculation failed.", true);
                    return;
                }
                if (result === null) {
                    // A 2xx answer that is not JSON (e.g. a followed redirect
                    // to an HTML page) cannot be used; report it below.
                    throw new Error("Fit endpoint returned a non-JSON response.");
                }

                const prefix = pathogenToFieldPrefix(pathogen);
                const minInput = sourceForm.querySelector(`#id_${prefix}_min`);
                const maxInput = sourceForm.querySelector(`#id_${prefix}_max`);
                if (minInput && maxInput) {
                    minInput.value = result.q025;
                    maxInput.value = result.q975;
                    // Native "change" events bubble; do the same so listeners
                    // attached to the form or document also see the update.
                    minInput.dispatchEvent(new Event("change", { bubbles: true }));
                    maxInput.dispatchEvent(new Event("change", { bubbles: true }));
                }
                renderSourceHistogram(pathogen, result.histogram);
                setPathogenMessage(pathogen, `Calculated from ${result.n_samples} values. Min=${result.q025}, Max=${result.q975}`);
            } catch (err) {
                console.error("Source inflow fit failed:", err);
                setPathogenMessage(pathogen, "Could not calculate distribution. Please try again.", true);
            } finally {
                button.disabled = false;
            }
        });
    });
}

Treatments

# ---------------------------------------------------------------------------
# qmra/risk_assessment/tests/test_risk_assessment_api.py
# ---------------------------------------------------------------------------
from django.contrib.auth import get_user_model
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import TestCase
from django.urls import reverse


class SourceInflowFitApiTests(TestCase):
    """Exercise the "source-inflow-fit" endpoint through the Django test client."""

    def setUp(self):
        user_model = get_user_model()
        self.user = user_model.objects.create_user(username="alice", password="secret123")
        self.url = reverse("source-inflow-fit")

    def _upload(self, pathogen: str, csv_text: str):
        """Log in as the test user and POST ``csv_text`` as a UTF-8 CSV upload."""
        return self._upload_bytes(pathogen, csv_text.encode("utf-8"))

    def _upload_bytes(self, pathogen: str, raw: bytes):
        """Log in as the test user and POST ``raw`` unchanged as the uploaded file."""
        self.client.force_login(self.user)
        f = SimpleUploadedFile("sample.csv", raw, content_type="text/csv")
        return self.client.post(self.url, {"pathogen": pathogen, "file": f})

    def test_fit_source_pathogen_distribution_success(self):
        csv_text = "Rotavirus\n1\n2\n3\n8\n9\n18\n"
        response = self._upload("Rotavirus", csv_text)

        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertEqual(payload["pathogen"], "Rotavirus")
        self.assertIn("q025", payload)
        self.assertIn("q975", payload)
        self.assertIn("histogram", payload)
        self.assertGreaterEqual(payload["q975"], payload["q025"])

    def test_fit_source_pathogen_distribution_accepts_utf8_bom(self):
        # Spreadsheet exports commonly prefix the file with a UTF-8 byte order
        # mark; it must not become part of the first column name.
        raw = b"\xef\xbb\xbf" + "Rotavirus\n1\n2\n3\n8\n9\n18\n".encode("utf-8")
        response = self._upload_bytes("Rotavirus", raw)

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json()["n_samples"], 6)

    def test_fit_source_pathogen_distribution_rejects_invalid_column(self):
        csv_text = "WrongColumn\n1\n2\n3\n4\n5\n"
        response = self._upload("Rotavirus", csv_text)

        self.assertEqual(response.status_code, 422)
        self.assertIn("column named", response.json()["error"])

    def test_fit_source_pathogen_distribution_rejects_non_integer_values(self):
        csv_text = "Rotavirus\n1.2\n2\n3\n4\n5\n"
        response = self._upload("Rotavirus", csv_text)

        self.assertEqual(response.status_code, 422)
        self.assertIn("integers", response.json()["error"])

    def test_fit_source_pathogen_distribution_requires_authentication(self):
        f = SimpleUploadedFile("sample.csv", b"Rotavirus\n1\n2\n3\n4\n5\n", content_type="text/csv")
        response = self.client.post(self.url, {"pathogen": "Rotavirus", "file": f})

        self.assertEqual(response.status_code, 302)


# ---------------------------------------------------------------------------
# qmra/risk_assessment/views.py
# (uses io, login_required, JsonResponse and HttpResponse from the module's
#  existing imports)
# ---------------------------------------------------------------------------
import math

import numpy as np
import pandas as pd


# Pathogen names the fit endpoint accepts; the uploaded CSV must contain a
# column with exactly this name.
ALLOWED_SOURCE_PATHOGENS = {
    "Rotavirus",
    "Campylobacter jejuni",
    "Cryptosporidium parvum",
}

# Upper bound on the upload size accepted by the fit endpoint.
_MAX_UPLOAD_BYTES = 5 * 1024 * 1024

# Upper bound on the number of histogram bars sent to the browser. The
# simulated counts can span millions of integers for realistic concentration
# data; one bar per integer would exhaust memory and bloat the response.
_MAX_HISTOGRAM_BINS = 200


def _bad_fit_response(message: str, status=422):
    """Return a JSON error body ``{"error": message}`` with the given status."""
    return JsonResponse({"error": message}, status=status)


def _histogram_of_counts(simulated: np.ndarray, max_bins: int = _MAX_HISTOGRAM_BINS):
    """Bin integer samples into at most ``max_bins`` equal-width bars.

    When the samples span no more than ``max_bins`` distinct integers every
    integer gets its own bar centred on it. Wider ranges are grouped into
    bins of equal integer width so the result stays bounded.

    Returns a dict with ``x`` (integer bar positions) and ``y`` (counts).
    """
    lowest = int(simulated.min())
    highest = int(simulated.max())
    span = highest - lowest + 1
    width = max(1, math.ceil(span / max_bins))
    n_bins = math.ceil(span / width)

    # Edges sit on half-integers so each integer falls strictly inside a bin;
    # the last edge is >= highest + 0.5 because width * n_bins >= span.
    edges = lowest - 0.5 + width * np.arange(n_bins + 1)
    counts, _ = np.histogram(simulated, bins=edges)
    positions = lowest + width * np.arange(n_bins) + (width - 1) // 2

    return {"x": positions.tolist(), "y": counts.tolist()}


def _fit_negative_binomial_from_series(series: pd.Series):
    """Fit a negative binomial to count measurements by the method of moments.

    The parameters are derived from the sample mean and the sample variance
    (``ddof=1``): ``r = mu**2 / (variance - mu)`` and ``p = r / (r + mu)``.
    5000 draws from the fitted distribution (fixed seed, so results are
    reproducible) provide the 2.5 % / 97.5 % quantiles and the histogram.

    Args:
        series: The measurements; every entry must parse as a non-negative
            whole number.

    Returns:
        A dict with ``n_samples``, ``mu``, ``variance``, ``r``, ``p``,
        ``q025`` (lower quantile rounded down), ``q975`` (upper quantile
        rounded up) and ``histogram`` (``x`` / ``y`` lists with at most
        ``_MAX_HISTOGRAM_BINS`` entries).

    Raises:
        ValueError: If the series is empty, contains non-numeric, negative or
            fractional values, has fewer than 5 entries, or its variance does
            not exceed its mean. The message is suitable for end users.
    """
    if series.empty:
        raise ValueError("The pathogen column has no values.")

    numeric = pd.to_numeric(series, errors="coerce")
    if numeric.isna().any():
        raise ValueError("All measurements must be integers.")
    if (numeric < 0).any():
        raise ValueError("All measurements must be non-negative integers.")
    if not ((numeric % 1) == 0).all():
        raise ValueError("All measurements must be integers.")

    values = numeric.astype(np.int64).to_numpy()
    if values.size < 5:
        raise ValueError("At least 5 measurements are required to fit a negative binomial distribution.")

    mu = float(values.mean())
    variance = float(values.var(ddof=1))
    # The moment estimator divides by (variance - mu); it is only defined for
    # over-dispersed data.
    if variance <= mu:
        raise ValueError("Negative binomial fit failed because variance must be greater than the mean.")

    r = (mu ** 2) / (variance - mu)
    p = r / (r + mu)

    if r <= 0 or p <= 0 or p >= 1:
        raise ValueError("Negative binomial fit failed due to invalid fitted parameters.")

    simulated = np.random.default_rng(7).negative_binomial(r, p, size=5000)
    q025 = float(np.quantile(simulated, 0.025))
    q975 = float(np.quantile(simulated, 0.975))

    return {
        "n_samples": int(values.size),
        "mu": mu,
        "variance": variance,
        "r": float(r),
        "p": float(p),
        "q025": float(math.floor(q025)),
        "q975": float(math.ceil(q975)),
        "histogram": _histogram_of_counts(simulated),
    }


@login_required(login_url="/login")
def fit_source_pathogen_distribution(request):
    """Fit a distribution to an uploaded CSV of pathogen measurements.

    Routed as ``source/inflow-fit`` under the URL name ``source-inflow-fit``.
    Expects a multipart POST with ``pathogen`` (one of
    ``ALLOWED_SOURCE_PATHOGENS``) and ``file`` (a UTF-8 CSV of at most 5 MB
    containing a column named exactly like the pathogen).

    Returns:
        200 with the fit result plus the ``pathogen`` name on success,
        422 with ``{"error": ...}`` for any invalid input, and an empty 404
        for non-POST requests.
    """
    # NOTE(review): 405 would be the conventional status for a wrong method;
    # 404 is kept so existing clients see no change.
    if request.method != "POST":
        return HttpResponse(status=404)

    pathogen = request.POST.get("pathogen")
    if pathogen not in ALLOWED_SOURCE_PATHOGENS:
        return _bad_fit_response("Unsupported pathogen selected.")

    csv_file = request.FILES.get("file")
    if csv_file is None:
        return _bad_fit_response("Please upload a CSV file.")
    if csv_file.size > _MAX_UPLOAD_BYTES:
        return _bad_fit_response("Uploaded file is too large. Maximum allowed size is 5 MB.")

    try:
        # "utf-8-sig" drops a leading byte order mark if present and is
        # otherwise identical to "utf-8"; without it the BOM would be glued
        # to the first column name and the lookup below would fail.
        content = csv_file.read().decode("utf-8-sig")
        df = pd.read_csv(io.StringIO(content))
    except Exception:
        return _bad_fit_response("Could not read the CSV file. Please upload a valid UTF-8 CSV.")

    if pathogen not in df.columns:
        return _bad_fit_response(f"CSV must include a column named '{pathogen}'.")

    try:
        fit = _fit_negative_binomial_from_series(df[pathogen].dropna())
    except ValueError as err:
        return _bad_fit_response(str(err))

    return JsonResponse({
        "pathogen": pathogen,
        **fit,
    })