Skip to content

Commit 19f3c38

Browse files
authored
Merge pull request #56 from X-DataInitiative/CNAM-149-MLPPExposuresTransformer
CNAM-149-MLPPExposuresTransformer
2 parents 50fc016 + be20702 commit 19f3c38

4 files changed

Lines changed: 350 additions & 2 deletions

File tree

src/main/scala/fr/polytechnique/cmap/cnam/filtering/ExposuresTransformer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ trait ExposuresTransformer extends DatasetTransformer[FlatEvent, FlatEvent] {
77

88
// Constant definitions. Should be verified before compiling.
99
// In the future, we may want to export them to an external file.
10-
val periodStart = makeTS(2006, 1, 1)
10+
val StudyStart = makeTS(2006, 1, 1)
1111

1212
def transform(input: Dataset[FlatEvent]): Dataset[FlatEvent]
1313
}

src/main/scala/fr/polytechnique/cmap/cnam/filtering/cox/CoxExposuresTransformer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ object CoxExposuresTransformer extends ExposuresTransformer {
4141
).over(window).cast(BooleanType)
4242

4343
// Drop patients whose first molecule event is after PeriodStart + 1 year
44-
val firstYearObservation = add_months(lit(periodStart), 12).cast(TimestampType)
44+
val firstYearObservation = add_months(lit(StudyStart), 12).cast(TimestampType)
4545
val drugFilter = max(
4646
when(
4747
col("category") === "molecule" && (col("start") <= firstYearObservation),
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package fr.polytechnique.cmap.cnam.filtering.mlpp
2+
3+
import org.apache.spark.sql.expressions.Window
4+
import org.apache.spark.sql.functions._
5+
import org.apache.spark.sql.types.{BooleanType, TimestampType}
6+
import org.apache.spark.sql.{Column, DataFrame, Dataset}
7+
import fr.polytechnique.cmap.cnam.filtering.{ExposuresTransformer, FlatEvent}
8+
9+
object MLPPExposuresTransformer extends ExposuresTransformer {
10+
11+
final val ExposureMinPurchases = 1
12+
final val ExposureStartDelay = 0
13+
final val ExposureStartThreshold = 6
14+
final val OnlyFirstExposure = false
15+
final val FilterDelayedEntries = true
16+
final val FilterDiagnosedPatients = true
17+
18+
val outputColumns = List(
19+
col("patientID"),
20+
col("gender"),
21+
col("birthDate"),
22+
col("deathDate"),
23+
lit("exposure").as("category"),
24+
col("eventId"),
25+
lit(1.0).as("weight"),
26+
col("exposureStart").as("start"),
27+
col("exposureEnd").as("end")
28+
)
29+
30+
implicit class ExposuresDataFrame(data: DataFrame) {
31+
32+
/**
33+
* Drops patients whose got a target disease before periodStart
34+
*/
35+
def filterDiagnosedPatients(doFilter: Boolean): DataFrame = doFilter match {
36+
case false => data
37+
case true => {
38+
val window = Window.partitionBy("patientID")
39+
40+
val filterColumn: Column = min(
41+
when(
42+
col("category") === "disease" &&
43+
col("eventId") === "targetDisease" &&
44+
(col("start") < StudyStart), lit(0)
45+
).otherwise(lit(1))
46+
).over(window).cast(BooleanType)
47+
48+
data.withColumn("filter", filterColumn).where(col("filter")).drop("filter")
49+
}
50+
}
51+
52+
/**
53+
* Drops patients whose first molecule event is after StudyStart + 1 year
54+
*/
55+
def filterDelayedEntries(doFilter: Boolean): DataFrame = doFilter match {
56+
case false => data
57+
case true => {
58+
val window = Window.partitionBy("patientID")
59+
60+
val firstYearObservation: Column = add_months(lit(StudyStart), 12).cast(TimestampType)
61+
val filterColumn: Column = max(
62+
when(
63+
col("category") === "molecule" && (col("start") <= firstYearObservation),
64+
lit(1)
65+
).otherwise(lit(0))
66+
).over(window).cast(BooleanType)
67+
68+
data.withColumn("filter", filterColumn).where(col("filter")).drop("filter")
69+
}
70+
}
71+
72+
def withExposureStart(minPurchases: Int = 1, intervalSize: Int = 6,
73+
startDelay: Int = 0, firstOnly: Boolean = false): DataFrame = {
74+
75+
val window = Window.partitionBy("patientID", "eventId")
76+
77+
// We don't lag the column if we want one exposure for every purchase
78+
val laggedStart: Column = if(minPurchases == 1)
79+
col("start")
80+
else
81+
lag(col("start"), minPurchases - 1).over(window.orderBy("start"))
82+
83+
val exposureStartRule: Column = when(
84+
months_between(col("start"), col("previousStartDate")) <= intervalSize,
85+
add_months(col("start"), startDelay).cast(TimestampType)
86+
)
87+
88+
val finalExposureStart: Column = if(firstOnly)
89+
min(exposureStartRule).over(window)
90+
else
91+
exposureStartRule
92+
93+
data
94+
.withColumn("previousStartDate", laggedStart)
95+
.withColumn("exposureStart", finalExposureStart)
96+
}
97+
98+
// For now, exposure end is the same as exposure start
99+
def withExposureEnd: DataFrame = {
100+
data.withColumn("exposureEnd", col("exposureStart"))
101+
}
102+
}
103+
104+
def transform(input: Dataset[FlatEvent]): Dataset[FlatEvent] = {
105+
import input.sqlContext.implicits._
106+
107+
input.toDF
108+
.filterDelayedEntries(FilterDelayedEntries)
109+
.filterDiagnosedPatients(FilterDiagnosedPatients)
110+
.where(col("category") === "molecule")
111+
.withExposureStart(
112+
minPurchases = ExposureMinPurchases,
113+
intervalSize = ExposureStartThreshold,
114+
startDelay = ExposureStartDelay,
115+
firstOnly = OnlyFirstExposure
116+
)
117+
.withExposureEnd
118+
.where(col("exposureStart").isNotNull)
119+
.select(outputColumns: _*)
120+
.dropDuplicates(Seq("patientID", "eventId", "start", "end"))
121+
.as[FlatEvent]
122+
}
123+
}
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package fr.polytechnique.cmap.cnam.filtering.mlpp
2+
3+
import fr.polytechnique.cmap.cnam.SharedContext
4+
import fr.polytechnique.cmap.cnam.filtering.FlatEvent
5+
import fr.polytechnique.cmap.cnam.utilities.RichDataFrames
6+
import fr.polytechnique.cmap.cnam.utilities.functions._
7+
8+
class MLPPExposuresTransformerSuite extends SharedContext {
9+
10+
"filterDelayedEntries" should "keep only patients who purchased a medicine in the first year of the study" in {
11+
val sqlCtx = sqlContext
12+
import sqlCtx.implicits._
13+
14+
// Given
15+
val input = Seq(
16+
("Patient_A", "molecule", "", makeTS(2008, 1, 1)),
17+
("Patient_A", "molecule", "", makeTS(2008, 2, 1)),
18+
("Patient_B", "molecule", "", makeTS(2009, 1, 1)),
19+
("Patient_C", "molecule", "", makeTS(2006, 2, 1)),
20+
("Patient_C", "molecule", "", makeTS(2006, 1, 1))
21+
).toDF("patientID", "category", "eventId", "start")
22+
23+
val expected = Seq(
24+
("Patient_C", "molecule"),
25+
("Patient_C", "molecule")
26+
).toDF("patientID", "category")
27+
28+
// When
29+
import MLPPExposuresTransformer.ExposuresDataFrame
30+
val result = input.filterDelayedEntries(true).select("patientID", "category")
31+
32+
// Then
33+
import RichDataFrames._
34+
assert(result === expected)
35+
}
36+
37+
it should "return the same data if we pass false" in {
38+
val sqlCtx = sqlContext
39+
import sqlCtx.implicits._
40+
41+
// Given
42+
val input = Seq(
43+
("Patient_A", "molecule", "", makeTS(2008, 1, 1)),
44+
("Patient_B", "molecule", "", makeTS(2009, 1, 1)),
45+
("Patient_C", "molecule", "", makeTS(2006, 1, 1))
46+
).toDF("patientID", "category", "eventId", "start")
47+
48+
val expected = input
49+
50+
// When
51+
import MLPPExposuresTransformer.ExposuresDataFrame
52+
val result = input.filterDelayedEntries(false)
53+
54+
// Then
55+
import RichDataFrames._
56+
assert(result === expected)
57+
}
58+
59+
"filterDiagnosedPatients" should "keep only patients who did not have a target disease before the study start" in {
60+
val sqlCtx = sqlContext
61+
import sqlCtx.implicits._
62+
63+
// Given
64+
val input = Seq(
65+
("Patient_A", "molecule", "", makeTS(2008, 1, 10)),
66+
("Patient_A", "disease", "targetDisease", makeTS(2005, 1, 1)),
67+
("Patient_B", "molecule", "", makeTS(2009, 1, 1)),
68+
("Patient_B", "disease", "targetDisease", makeTS(2009, 1, 1)),
69+
("Patient_C", "molecule", "", makeTS(2006, 1, 1))
70+
).toDF("patientID", "category", "eventId", "start")
71+
72+
val expected = Seq(
73+
("Patient_B", "molecule"),
74+
("Patient_B", "disease"),
75+
("Patient_C", "molecule")
76+
).toDF("patientID", "category")
77+
78+
// When
79+
import MLPPExposuresTransformer.ExposuresDataFrame
80+
val result = input.filterDiagnosedPatients(true).select("patientID", "category")
81+
82+
// Then
83+
import RichDataFrames._
84+
assert(result === expected)
85+
}
86+
87+
it should "return the same data if we pass false" in {
88+
val sqlCtx = sqlContext
89+
import sqlCtx.implicits._
90+
91+
// Given
92+
val input = Seq(
93+
("Patient_A", "molecule", "", makeTS(2008, 1, 10)),
94+
("Patient_A", "disease", "targetDisease", makeTS(2007, 1, 1))
95+
).toDF("patientID", "category", "eventId", "start")
96+
97+
val expected = input
98+
99+
// When
100+
import MLPPExposuresTransformer.ExposuresDataFrame
101+
val result = input.filterDiagnosedPatients(false)
102+
103+
// Then
104+
import RichDataFrames._
105+
assert(result === expected)
106+
}
107+
108+
"withExposureStart" should "add a column with the start of the default MLPP exposure definition" in {
109+
val sqlCtx = sqlContext
110+
import sqlCtx.implicits._
111+
112+
// Given
113+
val input = Seq(
114+
("Patient_A", "molecule", "PIOGLITAZONE", makeTS(2008, 1, 1)),
115+
("Patient_A", "molecule", "PIOGLITAZONE", makeTS(2008, 2, 1)),
116+
("Patient_A", "molecule", "SULFONYLUREA", makeTS(2008, 3, 1)),
117+
("Patient_B", "molecule", "PIOGLITAZONE", makeTS(2008, 4, 1)),
118+
("Patient_B", "molecule", "BENFLUOREX", makeTS(2008, 5, 1))
119+
).toDF("PatientID", "category", "eventId", "start")
120+
121+
val expected = Seq(
122+
("Patient_A", "PIOGLITAZONE", Some(makeTS(2008, 1, 1))),
123+
("Patient_A", "PIOGLITAZONE", Some(makeTS(2008, 2, 1))),
124+
("Patient_A", "SULFONYLUREA", Some(makeTS(2008, 3, 1))),
125+
("Patient_B", "PIOGLITAZONE", Some(makeTS(2008, 4, 1))),
126+
("Patient_B", "BENFLUOREX", Some(makeTS(2008, 5, 1)))
127+
).toDF("PatientID", "eventId", "exposureStart")
128+
129+
// When
130+
import MLPPExposuresTransformer.ExposuresDataFrame
131+
val result = input.withExposureStart(minPurchases = 1, firstOnly = false)
132+
.select("PatientID", "eventId", "exposureStart")
133+
134+
// Then
135+
import RichDataFrames._
136+
result.show
137+
expected.show
138+
assert(result === expected)
139+
}
140+
141+
it should "add a column with the start of the exposure, using a 'cox-like' definition" in {
142+
val sqlCtx = sqlContext
143+
import sqlCtx.implicits._
144+
145+
// Given
146+
val input = Seq(
147+
("Patient_A", "molecule", "PIOGLITAZONE", makeTS(2008, 6, 1)),
148+
("Patient_A", "molecule", "PIOGLITAZONE", makeTS(2008, 1, 1)),
149+
("Patient_A", "molecule", "PIOGLITAZONE", makeTS(2008, 8, 1)),
150+
("Patient_A", "molecule", "PIOGLITAZONE", makeTS(2008, 10, 1)),
151+
("Patient_A", "molecule", "PIOGLITAZONE", makeTS(2008, 11, 1)),
152+
("Patient_A", "molecule", "SULFONYLUREA", makeTS(2008, 9, 1)),
153+
("Patient_A", "molecule", "SULFONYLUREA", makeTS(2008, 10, 1)),
154+
("Patient_B", "molecule", "PIOGLITAZONE", makeTS(2009, 1, 1))
155+
).toDF("PatientID", "category", "eventId", "start")
156+
157+
val expected = Seq(
158+
("Patient_A", "PIOGLITAZONE", Some(makeTS(2008, 6, 1))),
159+
("Patient_A", "PIOGLITAZONE", Some(makeTS(2008, 6, 1))),
160+
("Patient_A", "PIOGLITAZONE", Some(makeTS(2008, 6, 1))),
161+
("Patient_A", "PIOGLITAZONE", Some(makeTS(2008, 6, 1))),
162+
("Patient_A", "PIOGLITAZONE", Some(makeTS(2008, 6, 1))),
163+
("Patient_A", "SULFONYLUREA", Some(makeTS(2008, 10, 1))),
164+
("Patient_A", "SULFONYLUREA", Some(makeTS(2008, 10, 1))),
165+
("Patient_B", "PIOGLITAZONE", None)
166+
).toDF("PatientID", "eventId", "exposureStart")
167+
168+
169+
// When
170+
import MLPPExposuresTransformer.ExposuresDataFrame
171+
val result = input.withExposureStart(
172+
minPurchases = 2, intervalSize = 6, startDelay = 0, firstOnly = true
173+
).select("PatientID", "eventId", "exposureStart")
174+
175+
// Then
176+
import RichDataFrames._
177+
result.show
178+
expected.show
179+
assert(result === expected)
180+
}
181+
182+
"transform" should "return a valid Dataset for a known input" in {
183+
184+
val sqlCtx = sqlContext
185+
import sqlCtx.implicits._
186+
187+
// Given
188+
val input = Seq(
189+
FlatEvent("Patient_A", 1, makeTS(1950, 1, 1), Some(makeTS(2009, 7, 11)), "molecule",
190+
"PIOGLITAZONE", 900.0, makeTS(2006, 1, 1), None),
191+
FlatEvent("Patient_A", 1, makeTS(1950, 1, 1), Some(makeTS(2009, 7, 11)), "molecule",
192+
"PIOGLITAZONE", 900.0, makeTS(2007, 2, 1), None),
193+
FlatEvent("Patient_A", 1, makeTS(1950, 1, 1), Some(makeTS(2009, 7, 11)), "molecule",
194+
"PIOGLITAZONE", 900.0, makeTS(2007, 5, 1), None),
195+
FlatEvent("Patient_B", 1, makeTS(1940, 1, 1), None, "molecule",
196+
"PIOGLITAZONE", 900.0, makeTS(2006, 1, 1), None),
197+
FlatEvent("Patient_B", 1, makeTS(1940, 1, 1), None, "molecule",
198+
"PIOGLITAZONE", 900.0, makeTS(2006, 5, 1), None),
199+
FlatEvent("Patient_C", 1, makeTS(1940, 1, 1), None, "molecule",
200+
"PIOGLITAZONE", 900.0, makeTS(2007, 8, 1), None)
201+
).toDS
202+
203+
val expected = Seq(
204+
FlatEvent("Patient_A", 1, makeTS(1950, 1, 1), Some(makeTS(2009, 7, 11)), "exposure",
205+
"PIOGLITAZONE", 1.0, makeTS(2006, 1, 1), Some(makeTS(2006, 1, 1))),
206+
FlatEvent("Patient_A", 1, makeTS(1950, 1, 1), Some(makeTS(2009, 7, 11)), "exposure",
207+
"PIOGLITAZONE", 1.0, makeTS(2007, 2, 1), Some(makeTS(2007, 2, 1))),
208+
FlatEvent("Patient_A", 1, makeTS(1950, 1, 1), Some(makeTS(2009, 7, 11)), "exposure",
209+
"PIOGLITAZONE", 1.0, makeTS(2007, 5, 1), Some(makeTS(2007, 5, 1))),
210+
FlatEvent("Patient_B", 1, makeTS(1940, 1, 1), None, "exposure",
211+
"PIOGLITAZONE", 1.0, makeTS(2006, 1, 1), Some(makeTS(2006, 1, 1))),
212+
FlatEvent("Patient_B", 1, makeTS(1940, 1, 1), None, "exposure",
213+
"PIOGLITAZONE", 1.0, makeTS(2006, 5, 1), Some(makeTS(2006, 5, 1)))
214+
).toDS.toDF
215+
216+
// When
217+
val result = MLPPExposuresTransformer.transform(input)
218+
219+
// Then
220+
result.show
221+
expected.show
222+
import RichDataFrames._
223+
assert(result.toDF === expected)
224+
}
225+
}

0 commit comments

Comments
 (0)