|
1 | 1 | package frameless |
2 | 2 |
|
3 | | -import org.scalacheck.Prop |
| 3 | +import org.apache.spark.sql.Encoder |
| 4 | +import org.scalacheck.{Arbitrary, Gen, Prop} |
4 | 5 | import org.scalacheck.Prop._ |
| 6 | +import org.scalatest.Matchers._ |
| 7 | + |
5 | 8 | import scala.reflect.ClassTag |
6 | 9 |
|
7 | 10 | class NumericTests extends TypedDatasetSuite { |
@@ -169,4 +172,35 @@ class NumericTests extends TypedDatasetSuite { |
169 | 172 | check(prop[Short ] _) |
170 | 173 | check(prop[BigDecimal] _) |
171 | 174 | } |
| 175 | + |
| 176 | + test("isNaN") { |
| 177 | + val spark = session |
| 178 | + import spark.implicits._ |
| 179 | + |
| 180 | + implicit val doubleWithNaN = Arbitrary { |
| 181 | + implicitly[Arbitrary[Double]].arbitrary.flatMap(Gen.oneOf(_, Double.NaN)) |
| 182 | + } |
| 183 | + implicit val x1 = Arbitrary{ doubleWithNaN.arbitrary.map(X1(_)) } |
| 184 | + |
| 185 | + def prop[A : TypedEncoder : Encoder : CatalystNaN](data: List[X1[A]]): Prop = { |
| 186 | + val ds = TypedDataset.create(data) |
| 187 | + |
| 188 | + val expected = ds.toDF().filter(!$"a".isNaN).map(_.getAs[A](0)).collect().toSeq |
| 189 | + val rs = ds.filter(!ds('a).isNaN).collect().run().map(_.a) |
| 190 | + |
| 191 | + rs ?= expected |
| 192 | + } |
| 193 | + |
| 194 | + check(forAll(prop[Float] _)) |
| 195 | + check(forAll(prop[Double] _)) |
| 196 | + } |
| 197 | + |
| 198 | + test("isNaN with non-nan types should not compile") { |
| 199 | + val ds = TypedDataset.create((1, false, 'a, "b") :: Nil) |
| 200 | + |
| 201 | + "ds.filter(ds('_1).isNaN)" shouldNot typeCheck |
| 202 | + "ds.filter(ds('_2).isNaN)" shouldNot typeCheck |
| 203 | + "ds.filter(ds('_3).isNaN)" shouldNot typeCheck |
| 204 | + "ds.filter(ds('_4).isNaN)" shouldNot typeCheck |
| 205 | + } |
172 | 206 | } |
0 commit comments