Skip to content

Commit cf321d7

Browse files
prestonph imarios
authored and committed
Add column method 'isNaN' (#313)
* Add column method 'isNaN' * Add a type class to restrict types that can be NaN
1 parent 3ad68b3 commit cf321d7

3 files changed

Lines changed: 58 additions & 1 deletion

File tree

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package frameless
2+
3+
import scala.annotation.implicitNotFound
4+
5+
/** Evidence that a column of type `A` can hold NaN; Spark performs the NaN check only for these types (Float and Double). */
6+
@implicitNotFound("Columns of type ${A} cannot be NaN.")
7+
trait CatalystNaN[A]
8+
9+
object CatalystNaN {
10+
private[this] val theInstance = new CatalystNaN[Any] {}
11+
private[this] def of[A]: CatalystNaN[A] = theInstance.asInstanceOf[CatalystNaN[A]]
12+
13+
implicit val framelessFloatNaN : CatalystNaN[Float] = of[Float]
14+
implicit val framelessDoubleNaN : CatalystNaN[Double] = of[Double]
15+
}
16+

dataset/src/main/scala/frameless/TypedColumn.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ abstract class AbstractTypedColumn[T, U]
139139
def isNotNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] =
140140
typed(Not(equalsTo(lit(None.asInstanceOf[U])).expr))
141141

142+
/** True if the current expression is a fractional number and is NaN.
143+
*
144+
* apache/spark
145+
*/
146+
def isNaN(implicit n: CatalystNaN[U]): ThisType[T, Boolean] =
147+
typed(self.untyped.isNaN)
148+
142149
/** Convert an Optional column by providing a default value
143150
* {{{
144151
* df( df('opt).getOrElse(df('defaultValue)) )

dataset/src/test/scala/frameless/NumericTests.scala

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package frameless
22

3-
import org.scalacheck.Prop
3+
import org.apache.spark.sql.Encoder
4+
import org.scalacheck.{Arbitrary, Gen, Prop}
45
import org.scalacheck.Prop._
6+
import org.scalatest.Matchers._
7+
58
import scala.reflect.ClassTag
69

710
class NumericTests extends TypedDatasetSuite {
@@ -169,4 +172,35 @@ class NumericTests extends TypedDatasetSuite {
169172
check(prop[Short ] _)
170173
check(prop[BigDecimal] _)
171174
}
175+
176+
test("isNaN") {
177+
val spark = session
178+
import spark.implicits._
179+
180+
implicit val doubleWithNaN = Arbitrary {
181+
implicitly[Arbitrary[Double]].arbitrary.flatMap(Gen.oneOf(_, Double.NaN))
182+
}
183+
implicit val x1 = Arbitrary{ doubleWithNaN.arbitrary.map(X1(_)) }
184+
185+
def prop[A : TypedEncoder : Encoder : CatalystNaN](data: List[X1[A]]): Prop = {
186+
val ds = TypedDataset.create(data)
187+
188+
val expected = ds.toDF().filter(!$"a".isNaN).map(_.getAs[A](0)).collect().toSeq
189+
val rs = ds.filter(!ds('a).isNaN).collect().run().map(_.a)
190+
191+
rs ?= expected
192+
}
193+
194+
check(forAll(prop[Float] _))
195+
check(forAll(prop[Double] _))
196+
}
197+
198+
test("isNaN with non-nan types should not compile") {
199+
val ds = TypedDataset.create((1, false, 'a, "b") :: Nil)
200+
201+
"ds.filter(ds('_1).isNaN)" shouldNot typeCheck
202+
"ds.filter(ds('_2).isNaN)" shouldNot typeCheck
203+
"ds.filter(ds('_3).isNaN)" shouldNot typeCheck
204+
"ds.filter(ds('_4).isNaN)" shouldNot typeCheck
205+
}
172206
}

0 commit comments

Comments
 (0)