Skip to content

Commit 5ed7fb1

Browse files
committed
CNAM-154 Important fixes in MLPP featuring
- Trackloss and diagnostic are no longer stop conditions for the featuring - Added the option to remove patients who had a single trackloss
1 parent f384aa6 commit 5ed7fb1

8 files changed

Lines changed: 149 additions & 34 deletions

File tree

src/main/resources/config/filtering-default.conf

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ default = {
5151
start_delay = 0
5252
purchases_window = 0
5353
only_first = false
54+
filter_lost_patients = false
5455
filter_diagnosed_patients = true
56+
diagnosed_patients_threshold = 0
5557
filter_delayed_entries = true
5658
delayed_entry_threshold = 12
5759
}

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPConfig.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ object MLPPConfig {
1313
startDelay: Int,
1414
purchasesWindow: Int,
1515
onlyFirst: Boolean,
16+
filterLostPatients: Boolean,
1617
filterDiagnosedPatients: Boolean,
18+
diagnosedPatientsThreshold: Int,
1719
filterDelayedEntries: Boolean,
1820
delayedEntryThreshold: Int
1921
)
@@ -30,7 +32,9 @@ object MLPPConfig {
3032
startDelay = conf.getInt("exposures.start_delay"),
3133
purchasesWindow = conf.getInt("exposures.purchases_window"),
3234
onlyFirst = conf.getBoolean("exposures.only_first"),
35+
filterLostPatients = conf.getBoolean("exposures.filter_lost_patients"),
3336
filterDiagnosedPatients = conf.getBoolean("exposures.filter_diagnosed_patients"),
37+
diagnosedPatientsThreshold = conf.getInt("exposures.diagnosed_patients_threshold"),
3438
filterDelayedEntries = conf.getBoolean("exposures.filter_delayed_entries"),
3539
delayedEntryThreshold = conf.getInt("exposures.delayed_entry_threshold")
3640
)

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPExposuresTransformer.scala

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
1212
private def startDelay = MLPPConfig.exposureDefinition.startDelay
1313
private def purchasesWindow = MLPPConfig.exposureDefinition.purchasesWindow
1414
private def onlyFirstExposure = MLPPConfig.exposureDefinition.onlyFirst
15+
private def filterLostPatients = MLPPConfig.exposureDefinition.filterLostPatients
1516
private def filterDelayedEntries = MLPPConfig.exposureDefinition.filterDelayedEntries
1617
private def delayedEntryThreshold = MLPPConfig.exposureDefinition.delayedEntryThreshold
1718
private def filterDiagnosedPatients = MLPPConfig.exposureDefinition.filterDiagnosedPatients
19+
private def diagnosedPatientsThreshold = MLPPConfig.exposureDefinition.diagnosedPatientsThreshold
1820

1921
val outputColumns = List(
2022
col("patientID"),
@@ -31,17 +33,22 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
3133
implicit class ExposuresDataFrame(data: DataFrame) {
3234

3335
/**
34-
* Drops patients whose got a target disease before periodStart
36+
* Drops patients whose got a target disease before periodStart + delay (default = 0)
3537
*/
3638
def filterDiagnosedPatients(doFilter: Boolean): DataFrame = {
3739

3840
if (doFilter) {
3941
val window = Window.partitionBy("patientID")
42+
43+
val dateThreshold: Column = add_months(
44+
lit(StudyStart), diagnosedPatientsThreshold
45+
).cast(TimestampType)
46+
4047
val filterColumn: Column = min(
4148
when(
42-
col("category") === "disease" &&
43-
col("eventId") === "targetDisease" &&
44-
(col("start") < StudyStart), lit(0)
49+
(col("category") === "disease") &&
50+
(col("eventId") === "targetDisease") &&
51+
(col("start") < dateThreshold), lit(0)
4552
).otherwise(lit(1))
4653
).over(window).cast(BooleanType)
4754

@@ -53,7 +60,7 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
5360
}
5461

5562
/**
56-
* Drops patients whose first molecule event is after StudyStart + 1 year
63+
* Drops patients whose first molecule event is after StudyStart + delay (default: 1 year)
5764
*/
5865
def filterDelayedEntries(doFilter: Boolean): DataFrame = {
5966
if (doFilter) {
@@ -77,6 +84,26 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
7784
}
7885
}
7986

87+
/**
88+
* Drops patients with trackloss events
89+
*/
90+
def filterLostPatients(doFilter: Boolean): DataFrame = {
91+
if (doFilter) {
92+
val window = Window.partitionBy("patientID")
93+
val filterColumn: Column = min(
94+
when(
95+
col("category") === "trackloss" && (col("start") >= StudyStart),
96+
lit(0)
97+
).otherwise(lit(1))
98+
).over(window).cast(BooleanType)
99+
100+
data.withColumn("filter", filterColumn).where(col("filter")).drop("filter")
101+
}
102+
else {
103+
data
104+
}
105+
}
106+
80107
def withExposureStart(minPurchases: Int = 1, intervalSize: Int = 6,
81108
startDelay: Int = 0, firstOnly: Boolean = false): DataFrame = {
82109

@@ -116,6 +143,7 @@ object MLPPExposuresTransformer extends ExposuresTransformer {
116143
input.toDF
117144
.filterDelayedEntries(filterDelayedEntries)
118145
.filterDiagnosedPatients(filterDiagnosedPatients)
146+
.filterLostPatients(filterLostPatients)
119147
.where(col("category") === "molecule")
120148
.withExposureStart(
121149
minPurchases = minPurchases,

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPMain.scala

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
package fr.polytechnique.cmap.cnam.filtering.mlpp
22

3-
import org.apache.spark.sql.Dataset
3+
import org.apache.spark.sql.{DataFrame, Dataset}
44
import org.apache.spark.sql.hive.HiveContext
5+
import org.apache.spark.sql.functions._
56
import fr.polytechnique.cmap.cnam.Main
6-
import fr.polytechnique.cmap.cnam.filtering.{FilteringConfig, FilteringMain, FlatEvent}
7+
import fr.polytechnique.cmap.cnam.filtering._
78

89
object MLPPMain extends Main {
910

1011
override def appName: String = "MLPPFeaturing"
1112

1213
def run(sqlContext: HiveContext, argsMap: Map[String, String] = Map()): Option[Dataset[MLPPFeature]] = {
1314

15+
import sqlContext.implicits._
16+
1417
// "get" returns an Option, then we can use foreach to gently ignore when the key was not found.
1518
argsMap.get("conf").foreach(sqlContext.setConf("conf", _))
1619
argsMap.get("env").foreach(sqlContext.setConf("env", _))
@@ -20,7 +23,23 @@ object MLPPMain extends Main {
2023
.filter(e => e.category == "molecule" || e.category == "disease").cache()
2124

2225
val diseaseEvents: Dataset[FlatEvent] = flatEvents.filter(_.category == "disease")
23-
val exposures: Dataset[FlatEvent] = MLPPExposuresTransformer.transform(flatEvents)
26+
val dcirFlat: DataFrame = sqlContext.read.parquet(FilteringConfig.inputPaths.dcir)
27+
28+
val patients: Dataset[Patient] = flatEvents.map(
29+
e => Patient(e.patientID, e.gender, e.birthDate, e.deathDate)
30+
).distinct
31+
val tracklossEvents: Dataset[Event] = TrackLossTransformer.transform(
32+
Sources(dcir=Some(dcirFlat))
33+
)
34+
val tracklossFlatEvents = tracklossEvents
35+
.as("left")
36+
.joinWith(patients.as("right"), col("left.patientID") === col("right.patientID"))
37+
.map((FlatEvent.merge _).tupled)
38+
.cache()
39+
40+
val allEvents = flatEvents.union(tracklossFlatEvents)
41+
42+
val exposures: Dataset[FlatEvent] = MLPPExposuresTransformer.transform(allEvents)
2443

2544
val mlppParams = MLPPWriter.Params(
2645
bucketSize = MLPPConfig.bucketSize,

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPWriter.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,18 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
6666

6767
val hadDisease: Column = (col("category") === "disease") &&
6868
(col("eventId") === "targetDisease") &&
69-
(col("startBucket") < minColumn(col("tracklossBucket"), col("deathBucket"), lit(bucketCount)))
69+
(col("startBucket") < minColumn(col("deathBucket"), lit(bucketCount)))
7070

7171
val diseaseBucket: Column = min(when(hadDisease, col("startBucket"))).over(window)
7272

7373
data.withColumn("diseaseBucket", diseaseBucket)
7474
}
7575

76+
// We are no longer using trackloss and disease information for calculating the end bucket.
7677
def withEndBucket: DataFrame = {
7778

7879
val endBucket: Column = minColumn(
79-
col("tracklossBucket"), col("diseaseBucket"), col("deathBucket"), lit(bucketCount)
80+
col("deathBucket"), lit(bucketCount)
8081
)
8182
data.withColumn("endBucket", endBucket)
8283
}
@@ -267,7 +268,6 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
267268
.withAge(AgeReferenceDate)
268269
.withStartBucket
269270
.withDeathBucket
270-
.withTracklossBucket
271271
.withDiseaseBucket
272272
.withEndBucket
273273
.where(col("category") === "exposure")

src/test/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPExposuresTransformerSuite.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class MLPPExposuresTransformerSuite extends SharedContext {
6565
("Patient_A", "molecule", "", makeTS(2008, 1, 10)),
6666
("Patient_A", "disease", "targetDisease", makeTS(2005, 1, 1)),
6767
("Patient_B", "molecule", "", makeTS(2009, 1, 1)),
68-
("Patient_B", "disease", "targetDisease", makeTS(2009, 1, 1)),
68+
("Patient_B", "disease", "targetDisease", makeTS(2006, 1, 1)),
6969
("Patient_C", "molecule", "", makeTS(2006, 1, 1))
7070
).toDF("patientID", "category", "eventId", "start")
7171

@@ -81,6 +81,8 @@ class MLPPExposuresTransformerSuite extends SharedContext {
8181

8282
// Then
8383
import RichDataFrames._
84+
result.show
85+
expected.show
8486
assert(result === expected)
8587
}
8688

@@ -105,6 +107,34 @@ class MLPPExposuresTransformerSuite extends SharedContext {
105107
assert(result === expected)
106108
}
107109

110+
"filterLostPatients" should "remove patients when they have a trackloss events" in {
111+
val sqlCtx = sqlContext
112+
import sqlCtx.implicits._
113+
114+
// Given
115+
val input = Seq(
116+
("Patient_A", "molecule", makeTS(2006, 1, 1)),
117+
("Patient_A", "molecule", makeTS(2006, 2, 1)),
118+
("Patient_B", "molecule", makeTS(2006, 5, 1)),
119+
("Patient_B", "trackloss", makeTS(2007, 1, 1)),
120+
("Patient_C", "molecule", makeTS(2006, 11, 1))
121+
).toDF("patientID", "category", "start")
122+
123+
val expected = Seq(
124+
("Patient_A", "molecule", makeTS(2006, 1, 1)),
125+
("Patient_A", "molecule", makeTS(2006, 2, 1)),
126+
("Patient_C", "molecule", makeTS(2006, 11, 1))
127+
).toDF("patientID", "category", "start")
128+
129+
// When
130+
import MLPPExposuresTransformer.ExposuresDataFrame
131+
val result = input.filterLostPatients(true)
132+
133+
// Then
134+
import RichDataFrames._
135+
assert(result === expected)
136+
}
137+
108138
"withExposureStart" should "add a column with the start of the default MLPP exposure definition" in {
109139
val sqlCtx = sqlContext
110140
import sqlCtx.implicits._

src/test/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPMainSuite.scala

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,26 @@ class MLPPMainSuite extends SharedContext {
2626
lazy val featuresPath = FilteringConfig.outputPaths.mlppFeatures
2727

2828
val expectedFeatures: DataFrame = Seq(
29-
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 0, 0, 0, 0, 1.0),
30-
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 0, 1, 0, 1.0),
31-
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 1, 1, 1, 1.0),
32-
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 1, 2, 1, 1.0),
33-
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 2, 2, 2, 1.0)
29+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 0, 0, 0, 0, 1.0),
30+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 0, 1, 0, 1.0),
31+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 1, 1, 1, 1.0),
32+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 1, 2, 1, 1.0),
33+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 2, 2, 2, 1.0),
34+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 3, 2, 3, 2, 1.0),
35+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 3, 3, 3, 3, 1.0),
36+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 4, 3, 4, 3, 1.0),
37+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 4, 4, 4, 4, 1.0),
38+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 5, 4, 5, 4, 1.0),
39+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 5, 5, 5, 5, 1.0),
40+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 6, 5, 6, 5, 1.0),
41+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 6, 6, 6, 6, 1.0),
42+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 7, 6, 7, 6, 1.0),
43+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 7, 7, 7, 7, 1.0),
44+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 8, 7, 8, 7, 1.0),
45+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 8, 8, 8, 8, 1.0),
46+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 9, 8, 9, 8, 1.0),
47+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 9, 9, 9, 9, 1.0),
48+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 10, 9, 10, 9, 1.0)
3449
).toDF
3550

3651
// When
@@ -56,7 +71,14 @@ class MLPPMainSuite extends SharedContext {
5671
val expectedFeatures: DataFrame = Seq(
5772
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 0, 0, 0, 0, 1.0),
5873
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 1, 1, 1, 1, 1.0),
59-
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 2, 2, 2, 1.0)
74+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 2, 2, 2, 2, 1.0),
75+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 3, 3, 3, 3, 1.0),
76+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 4, 4, 4, 4, 1.0),
77+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 5, 5, 5, 5, 1.0),
78+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 6, 6, 6, 6, 1.0),
79+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 7, 7, 7, 7, 1.0),
80+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 8, 8, 8, 8, 1.0),
81+
MLPPFeature("Patient_02", 0, "PIOGLITAZONE", 0, 9, 9, 9, 9, 1.0)
6082
).toDF
6183

6284
// When

src/test/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPWriterSuite.scala

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,14 @@ class MLPPWriterSuite extends SharedContext {
221221
val expected = Seq(
222222
("PA", Some(2)),
223223
("PA", Some(2)),
224-
("PB", Some(3)),
225-
("PB", Some(3)),
226-
("PC", Some(4)),
227-
("PC", Some(4)),
224+
("PB", Some(4)),
225+
("PB", Some(4)),
226+
("PC", Some(16)),
227+
("PC", Some(16)),
228228
("PD", Some(5)),
229229
("PD", Some(5)),
230-
("PE", Some(6)),
231-
("PE", Some(6)),
230+
("PE", Some(7)),
231+
("PE", Some(7)),
232232
("PF", Some(16))
233233
).toDF("patientID", "endBucket")
234234

@@ -605,16 +605,26 @@ class MLPPWriterSuite extends SharedContext {
605605

606606
val expectedFeatures = Seq(
607607
// Patient A
608-
MLPPFeature("PA", 0, "Mol1", 0, 0, 0, 0, 0, 1.0),
609-
MLPPFeature("PA", 0, "Mol1", 0, 1, 1, 1, 1, 1.0),
610-
MLPPFeature("PA", 0, "Mol1", 0, 2, 2, 2, 2, 1.0),
611-
MLPPFeature("PA", 0, "Mol1", 0, 3, 3, 3, 3, 1.0),
612-
MLPPFeature("PA", 0, "Mol1", 0, 2, 0, 2, 0, 1.0),
613-
MLPPFeature("PA", 0, "Mol1", 0, 3, 1, 3, 1, 1.0),
614-
MLPPFeature("PA", 0, "Mol1", 0, 3, 0, 3, 0, 1.0),
615-
MLPPFeature("PA", 0, "Mol2", 1, 2, 0, 2, 4, 1.0),
616-
MLPPFeature("PA", 0, "Mol2", 1, 3, 1, 3, 5, 1.0),
617-
MLPPFeature("PA", 0, "Mol3", 2, 3, 0, 3, 8, 1.0)
608+
MLPPFeature("PA", 0, "Mol1", 0, 0, 0, 0, 0, 1.0),
609+
MLPPFeature("PA", 0, "Mol1", 0, 1, 1, 1, 1, 1.0),
610+
MLPPFeature("PA", 0, "Mol1", 0, 2, 2, 2, 2, 1.0),
611+
MLPPFeature("PA", 0, "Mol1", 0, 3, 3, 3, 3, 1.0),
612+
MLPPFeature("PA", 0, "Mol1", 0, 2, 0, 2, 0, 1.0),
613+
MLPPFeature("PA", 0, "Mol1", 0, 3, 1, 3, 1, 1.0),
614+
MLPPFeature("PA", 0, "Mol1", 0, 4, 2, 4, 2, 1.0),
615+
MLPPFeature("PA", 0, "Mol1", 0, 5, 3, 5, 3, 1.0),
616+
MLPPFeature("PA", 0, "Mol1", 0, 3, 0, 3, 0, 1.0),
617+
MLPPFeature("PA", 0, "Mol1", 0, 4, 1, 4, 1, 1.0),
618+
MLPPFeature("PA", 0, "Mol1", 0, 5, 2, 5, 2, 1.0),
619+
MLPPFeature("PA", 0, "Mol1", 0, 6, 3, 6, 3, 1.0),
620+
MLPPFeature("PA", 0, "Mol2", 1, 2, 0, 2, 4, 1.0),
621+
MLPPFeature("PA", 0, "Mol2", 1, 3, 1, 3, 5, 1.0),
622+
MLPPFeature("PA", 0, "Mol2", 1, 4, 2, 4, 6, 1.0),
623+
MLPPFeature("PA", 0, "Mol2", 1, 5, 3, 5, 7, 1.0),
624+
MLPPFeature("PA", 0, "Mol3", 2, 3, 0, 3, 8, 1.0),
625+
MLPPFeature("PA", 0, "Mol3", 2, 4, 1, 4, 9, 1.0),
626+
MLPPFeature("PA", 0, "Mol3", 2, 5, 2, 5, 10, 1.0),
627+
MLPPFeature("PA", 0, "Mol3", 2, 6, 3, 6, 11, 1.0)
618628
).toDF
619629

620630
val expectedZMatrix = Seq(

0 commit comments

Comments
 (0)