diff --git a/qmra/risk_assessment/templates/assessment-configurator.html b/qmra/risk_assessment/templates/assessment-configurator.html
index 638cbcf..5c9ccda 100644
--- a/qmra/risk_assessment/templates/assessment-configurator.html
+++ b/qmra/risk_assessment/templates/assessment-configurator.html
@@ -130,6 +130,37 @@
Treatments
{% crispy user_source_form %}
+
{% crispy user_treatment_form %}
@@ -179,6 +210,94 @@
Treatments
form.querySelector("input[type=submit]").addEventListener("click", handleSubmit);
form.addEventListener("submit", function(){ return false; });
}
+
+ function pathogenToFieldPrefix(pathogen) {
+ if (pathogen === "Rotavirus") return "rotavirus";
+ if (pathogen === "Campylobacter jejuni") return "campylobacter";
+ return "cryptosporidium";
+ }
+
+ function pathogenToHistogramId(pathogen) {
+ return `source-fit-histogram-${pathogen.replaceAll(" ", "-")}`;
+ }
+
+ function setPathogenMessage(pathogen, message, isError=false) {
+ const msgNode = document.querySelector(`[data-pathogen-message="${pathogen}"]`);
+ if (!msgNode) { return; }
+ msgNode.classList.toggle("text-danger", isError);
+ msgNode.classList.toggle("text-success", !isError);
+ msgNode.textContent = message;
+ }
+
+ function renderSourceHistogram(pathogen, histogram) {
+ const targetId = pathogenToHistogramId(pathogen);
+ const target = document.getElementById(targetId);
+ if (!target) { return; }
+ Plotly.react(targetId, [{
+ x: histogram.x,
+ y: histogram.y,
+ type: "bar",
+ marker: { color: "#4f5dff" },
+ name: "Simulated count"
+ }], {
+ margin: {l: 35, r: 10, t: 20, b: 35},
+ height: 180,
+ paper_bgcolor: "#ffffff",
+ plot_bgcolor: "#ffffff",
+ xaxis: {title: "Concentration"},
+ yaxis: {title: "Frequency"}
+ }, {displaylogo: false, responsive: true});
+ }
+
+ async function bindSourceFitHandlers() {
+ const sourceForm = document.querySelector("#user-source-form");
+ if (!sourceForm) { return; }
+ const csrfToken = document.querySelector('[name=csrfmiddlewaretoken]')?.value;
+ document.querySelectorAll("#create-source [data-pathogen-calc]").forEach(button => {
+ button.addEventListener("click", async () => {
+ const pathogen = button.getAttribute("data-pathogen-calc");
+ const fileInput = document.querySelector(`#create-source [data-pathogen-file="${pathogen}"]`);
+ const file = fileInput?.files?.[0];
+ if (!file) {
+ setPathogenMessage(pathogen, "Please select a CSV file first.", true);
+ return;
+ }
+
+ const payload = new FormData();
+ payload.append("pathogen", pathogen);
+ payload.append("file", file);
+ setPathogenMessage(pathogen, "Calculating distribution...");
+
+ try {
+ const response = await fetch("{% url 'source-inflow-fit' %}", {
+ method: "POST",
+ headers: {'X-CSRFToken': csrfToken},
+ body: payload
+ });
+ const result = await response.json();
+ if (!response.ok) {
+ setPathogenMessage(pathogen, result.error || "Calculation failed.", true);
+ return;
+ }
+
+ const prefix = pathogenToFieldPrefix(pathogen);
+ const minInput = sourceForm.querySelector(`#id_${prefix}_min`);
+ const maxInput = sourceForm.querySelector(`#id_${prefix}_max`);
+ if (minInput && maxInput) {
+ minInput.value = result.q025;
+ maxInput.value = result.q975;
+ minInput.dispatchEvent(new Event("change"));
+ maxInput.dispatchEvent(new Event("change"));
+ }
+ renderSourceHistogram(pathogen, result.histogram);
+ setPathogenMessage(pathogen, `Calculated from ${result.n_samples} values. Min=${result.q025}, Max=${result.q975}`);
+ } catch (err) {
+ setPathogenMessage(pathogen, "Could not calculate distribution. Please try again.", true);
+ }
+ });
+ });
+ }
+
document.addEventListener("DOMContentLoaded", function() {
const userExposureForm = document.querySelector("#user-exposure-form");
if (null !== userExposureForm) { addSubmitHandler(userExposureForm) }
@@ -186,6 +305,7 @@ Treatments
if (null !== userSourceForm) { addSubmitHandler(userSourceForm) }
const userTreatmentForm = document.querySelector("#user-treatment-form");
if (null !== userTreatmentForm) { addSubmitHandler(userTreatmentForm) };
+ bindSourceFitHandlers();
})
diff --git a/qmra/risk_assessment/tests/test_risk_assessment_api.py b/qmra/risk_assessment/tests/test_risk_assessment_api.py
index 8cc9d63..dea6648 100644
--- a/qmra/risk_assessment/tests/test_risk_assessment_api.py
+++ b/qmra/risk_assessment/tests/test_risk_assessment_api.py
@@ -1 +1,50 @@
"""test get, create, update, delete requests"""
+
+from django.contrib.auth import get_user_model
+from django.test import TestCase
+from django.urls import reverse
+from django.core.files.uploadedfile import SimpleUploadedFile
+
+
+class SourceInflowFitApiTests(TestCase):
+ def setUp(self):
+ user_model = get_user_model()
+ self.user = user_model.objects.create_user(username="alice", password="secret123")
+ self.url = reverse("source-inflow-fit")
+
+ def _upload(self, pathogen: str, csv_text: str):
+ self.client.force_login(self.user)
+ f = SimpleUploadedFile("sample.csv", csv_text.encode("utf-8"), content_type="text/csv")
+ return self.client.post(self.url, {"pathogen": pathogen, "file": f})
+
+ def test_fit_source_pathogen_distribution_success(self):
+ csv_text = "Rotavirus\n1\n2\n3\n8\n9\n18\n"
+ response = self._upload("Rotavirus", csv_text)
+
+ self.assertEqual(response.status_code, 200)
+ payload = response.json()
+ self.assertEqual(payload["pathogen"], "Rotavirus")
+ self.assertIn("q025", payload)
+ self.assertIn("q975", payload)
+ self.assertIn("histogram", payload)
+ self.assertGreaterEqual(payload["q975"], payload["q025"])
+
+ def test_fit_source_pathogen_distribution_rejects_invalid_column(self):
+ csv_text = "WrongColumn\n1\n2\n3\n4\n5\n"
+ response = self._upload("Rotavirus", csv_text)
+
+ self.assertEqual(response.status_code, 422)
+ self.assertIn("column named", response.json()["error"])
+
+ def test_fit_source_pathogen_distribution_rejects_non_integer_values(self):
+ csv_text = "Rotavirus\n1.2\n2\n3\n4\n5\n"
+ response = self._upload("Rotavirus", csv_text)
+
+ self.assertEqual(response.status_code, 422)
+ self.assertIn("integers", response.json()["error"])
+
+ def test_fit_source_pathogen_distribution_requires_authentication(self):
+ f = SimpleUploadedFile("sample.csv", b"Rotavirus\n1\n2\n3\n4\n5\n", content_type="text/csv")
+ response = self.client.post(self.url, {"pathogen": "Rotavirus", "file": f})
+
+ self.assertEqual(response.status_code, 302)
diff --git a/qmra/risk_assessment/urls.py b/qmra/risk_assessment/urls.py
index f6b8bcb..ed8e0a6 100644
--- a/qmra/risk_assessment/urls.py
+++ b/qmra/risk_assessment/urls.py
@@ -48,6 +48,11 @@
views.create_source,
name="source"
),
+ path(
+ "source/inflow-fit",
+ views.fit_source_pathogen_distribution,
+ name="source-inflow-fit"
+ ),
path(
"sources",
views.list_sources,
diff --git a/qmra/risk_assessment/views.py b/qmra/risk_assessment/views.py
index eb9ec92..2a6e83c 100644
--- a/qmra/risk_assessment/views.py
+++ b/qmra/risk_assessment/views.py
@@ -1,4 +1,5 @@
import io
+import math
from crispy_forms.utils import render_crispy_form
from django.contrib.auth.decorators import login_required
@@ -16,6 +17,69 @@
from qmra.risk_assessment.user_models import UserExposureForm, UserTreatmentForm, UserSourceForm, UserExposure, \
UserSource, UserTreatment
+import numpy as np
+import pandas as pd
+
+
+ALLOWED_SOURCE_PATHOGENS = {
+ "Rotavirus",
+ "Campylobacter jejuni",
+ "Cryptosporidium parvum",
+}
+
+
+def _bad_fit_response(message: str, status=422):
+ return JsonResponse({"error": message}, status=status)
+
+
+def _fit_negative_binomial_from_series(series: pd.Series):
+ if series.empty:
+ raise ValueError("The pathogen column has no values.")
+
+ numeric = pd.to_numeric(series, errors="coerce")
+ if numeric.isna().any():
+ raise ValueError("All measurements must be integers.")
+ if (numeric < 0).any():
+ raise ValueError("All measurements must be non-negative integers.")
+ if not ((numeric % 1) == 0).all():
+ raise ValueError("All measurements must be integers.")
+
+ values = numeric.astype(np.int64).to_numpy()
+ if values.size < 5:
+ raise ValueError("At least 5 measurements are required to fit a negative binomial distribution.")
+
+ mu = float(values.mean())
+ variance = float(values.var(ddof=1))
+ if variance <= mu:
+ raise ValueError("Negative binomial fit failed because variance must be greater than the mean.")
+
+ r = (mu ** 2) / (variance - mu)
+ p = r / (r + mu)
+
+ if r <= 0 or p <= 0 or p >= 1:
+ raise ValueError("Negative binomial fit failed due to invalid fitted parameters.")
+
+ simulated = np.random.default_rng(7).negative_binomial(r, p, size=5000)
+ q025 = float(np.quantile(simulated, 0.025))
+ q975 = float(np.quantile(simulated, 0.975))
+
+ bins = np.arange(simulated.min(), simulated.max() + 2) - 0.5
+ hist_counts, hist_edges = np.histogram(simulated, bins=bins)
+ centers = ((hist_edges[:-1] + hist_edges[1:]) / 2).astype(int)
+
+ return {
+ "n_samples": int(values.size),
+ "mu": mu,
+ "variance": variance,
+ "r": float(r),
+ "p": float(p),
+ "q025": float(math.floor(q025)),
+ "q975": float(math.ceil(q975)),
+ "histogram": {
+ "x": centers.tolist(),
+ "y": hist_counts.tolist(),
+ }
+ }
@transaction.atomic
@@ -265,6 +329,41 @@ def list_sources(request):
return JsonResponse({s["name"]: s for s in UserSource.objects.filter(user=request.user).values().all()})
+@login_required(login_url="/login")
+def fit_source_pathogen_distribution(request):
+ if request.method != "POST":
+ return HttpResponse(status=404)
+
+ pathogen = request.POST.get("pathogen")
+ if pathogen not in ALLOWED_SOURCE_PATHOGENS:
+ return _bad_fit_response("Unsupported pathogen selected.")
+
+ csv_file = request.FILES.get("file")
+ if csv_file is None:
+ return _bad_fit_response("Please upload a CSV file.")
+ if csv_file.size > 5 * 1024 * 1024:
+ return _bad_fit_response("Uploaded file is too large. Maximum allowed size is 5 MB.")
+
+ try:
+ content = csv_file.read().decode("utf-8")
+ df = pd.read_csv(io.StringIO(content))
+ except Exception:
+ return _bad_fit_response("Could not read the CSV file. Please upload a valid UTF-8 CSV.")
+
+ if pathogen not in df.columns:
+ return _bad_fit_response(f"CSV must include a column named '{pathogen}'.")
+
+ try:
+ fit = _fit_negative_binomial_from_series(df[pathogen].dropna())
+ except ValueError as err:
+ return _bad_fit_response(str(err))
+
+ return JsonResponse({
+ "pathogen": pathogen,
+ **fit,
+ })
+
+
def list_inflows(request):
if not request.user.is_authenticated:
return JsonResponse({})