@@ -2,13 +2,18 @@ package frameless
22
33import org .scalacheck .Prop
44import org .scalacheck .Prop ._
5- import org .apache .spark .sql .{SparkSession , functions => sparkFunctions }
5+ import org .apache .spark .sql .{ SparkSession , functions => sparkFunctions }
66
77class SelfJoinTests extends TypedDatasetSuite {
8+
89 // Without crossJoin.enabled=true Spark doesn't like trivial join conditions:
910 // [error] Join condition is missing or trivial.
1011 // [error] Use the CROSS JOIN syntax to allow cartesian products between these relations.
11- def allowTrivialJoin [T ](body : => T )(implicit session : SparkSession ): T = {
12+ def allowTrivialJoin [T ](
13+ body : => T
14+ )(implicit
15+ session : SparkSession
16+ ): T = {
1217 val crossJoin = " spark.sql.crossJoin.enabled"
1318 val oldSetting = session.conf.get(crossJoin)
1419 session.conf.set(crossJoin, " true" )
@@ -17,7 +22,11 @@ class SelfJoinTests extends TypedDatasetSuite {
1722 result
1823 }
1924
20- def allowAmbiguousJoin [T ](body : => T )(implicit session : SparkSession ): T = {
25+ def allowAmbiguousJoin [T ](
26+ body : => T
27+ )(implicit
28+ session : SparkSession
29+ ): T = {
2130 val crossJoin = " spark.sql.analyzer.failAmbiguousSelfJoin"
2231 val oldSetting = session.conf.get(crossJoin)
2332 session.conf.set(crossJoin, " false" )
@@ -27,113 +36,177 @@ class SelfJoinTests extends TypedDatasetSuite {
2736 }
2837
2938 test(" self join with colLeft/colRight disambiguation" ) {
30- def prop [
31- A : TypedEncoder : Ordering ,
32- B : TypedEncoder : Ordering
33- ]( dx : List [ X2 [ A , B ]], d : X2 [ A , B ] ): Prop = allowAmbiguousJoin {
39+ def prop [A : TypedEncoder : Ordering , B : TypedEncoder : Ordering ](
40+ dx : List [ X2 [ A , B ]] ,
41+ d : X2 [ A , B ]
42+ ): Prop = allowAmbiguousJoin {
3443 val data = d :: dx
3544 val ds = TypedDataset .create(data)
3645
3746 // This is the way to write unambiguous self-join in vanilla, see https://goo.gl/XnkSUD
3847 val df1 = ds.dataset.as(" df1" )
3948 val df2 = ds.dataset.as(" df2" )
40- val vanilla = df1.join(df2,
41- sparkFunctions.col(" df1.a" ) === sparkFunctions.col(" df2.a" )).count()
49+ val vanilla = df1
50+ .join(df2, sparkFunctions.col(" df1.a" ) === sparkFunctions.col(" df2.a" ))
51+ .count()
4252
43- val typed = ds.joinInner(ds)(
44- ds.colLeft(' a ) === ds.colRight(' a )
45- ).count().run()
53+ val typed = ds
54+ .joinInner(ds)(
55+ ds.colLeft(' a ) === ds.colRight(' a )
56+ )
57+ .count()
58+ .run()
4659
4760 vanilla ?= typed
4861 }
4962
5063 check(prop[Int , Int ] _)
5164 }
5265
66+ test(" self join collects correct values via colLeft/colRight" ) {
67+ def prop [A : TypedEncoder : Ordering , B : TypedEncoder : Ordering ](
68+ dx : List [X2 [A , B ]],
69+ d : X2 [A , B ]
70+ ): Prop = allowAmbiguousJoin {
71+ val data = d :: dx
72+ val ds = TypedDataset .create(data)
73+
74+ // Collecting the joined tuples exercises the colLeft/colRight disambiguation and the
75+ // (T, U) ExpressionEncoder end to end, not just the row count: a regression guard for
76+ // Spark 4, where columns no longer wrap Catalyst expressions directly.
77+ val typed = ds
78+ .joinInner(ds)(ds.colLeft(' a ) === ds.colRight(' a ))
79+ .collect()
80+ .run()
81+ .toVector
82+ .sorted
83+
84+ val expected = (for {
85+ l <- data
86+ r <- data
87+ if l.a == r.a
88+ } yield (l, r)).toVector.sorted
89+
90+ typed ?= expected
91+ }
92+
93+ check(prop[Int , Int ] _)
94+ check(prop[String , Long ] _)
95+ }
96+
5397 test(" trivial self join" ) {
54- def prop [
55- A : TypedEncoder : Ordering ,
56- B : TypedEncoder : Ordering
57- ](dx : List [X2 [A , B ]], d : X2 [A , B ]): Prop =
58- allowTrivialJoin { allowAmbiguousJoin {
59-
60- val data = d :: dx
61- val ds = TypedDataset .create(data)
62- val untyped = ds.dataset
63- // Interestingly, even with aliasing it seems that it's impossible to
64- // obtain a trivial join condition of shape df1.a == df1.a, Spark we
65- // always interpret that as df1.a == df2.a. For the purpose of this
66- // test we fall-back to lit(true) instead.
67- // val trivial = sparkFunctions.col("df1.a") === sparkFunctions.col("df1.a")
68- val trivial = sparkFunctions.lit(true )
69- val vanilla = untyped.as(" df1" ).join(untyped.as(" df2" ), trivial).count()
70-
71- val typed = ds.joinInner(ds)(ds.colLeft(' a ) === ds.colLeft(' a )).count().run
72- vanilla ?= typed
73- } }
98+ def prop [A : TypedEncoder : Ordering , B : TypedEncoder : Ordering ](
99+ dx : List [X2 [A , B ]],
100+ d : X2 [A , B ]
101+ ): Prop =
102+ allowTrivialJoin {
103+ allowAmbiguousJoin {
104+
105+ val data = d :: dx
106+ val ds = TypedDataset .create(data)
107+ val untyped = ds.dataset
108+ // Interestingly, even with aliasing it seems that it's impossible to
109+ // obtain a trivial join condition of shape df1.a == df1.a, Spark we
110+ // always interpret that as df1.a == df2.a. For the purpose of this
111+ // test we fall-back to lit(true) instead.
112+ // val trivial = sparkFunctions.col("df1.a") === sparkFunctions.col("df1.a")
113+ val trivial = sparkFunctions.lit(true )
114+ val vanilla =
115+ untyped.as(" df1" ).join(untyped.as(" df2" ), trivial).count()
116+
117+ val typed =
118+ ds.joinInner(ds)(ds.colLeft(' a ) === ds.colLeft(' a )).count().run
119+ vanilla ?= typed
120+ }
121+ }
74122
75123 check(prop[Int , Int ] _)
76124 }
77125
78126 test(" self join with unambiguous expression" ) {
79127 def prop [
80- A : TypedEncoder : CatalystNumeric : Ordering ,
81- B : TypedEncoder : Ordering
82- ](data : List [X3 [A , A , B ]]): Prop = allowAmbiguousJoin {
128+ A : TypedEncoder : CatalystNumeric : Ordering ,
129+ B : TypedEncoder : Ordering
130+ ](data : List [X3 [A , A , B ]]
131+ ): Prop = allowAmbiguousJoin {
83132 val ds = TypedDataset .create(data)
84133
85134 val df1 = ds.dataset.alias(" df1" )
86135 val df2 = ds.dataset.alias(" df2" )
87136
88- val vanilla = df1.join(df2,
89- (sparkFunctions.col(" df1.a" ) + sparkFunctions.col(" df1.b" )) ===
90- (sparkFunctions.col(" df2.a" ) + sparkFunctions.col(" df2.b" ))).count()
91-
92- val typed = ds.joinInner(ds)(
93- (ds.colLeft(' a ) + ds.colLeft(' b )) === (ds.colRight(' a ) + ds.colRight(' b ))
94- ).count().run()
137+ val vanilla = df1
138+ .join(
139+ df2,
140+ (sparkFunctions.col(" df1.a" ) + sparkFunctions.col(" df1.b" )) ===
141+ (sparkFunctions.col(" df2.a" ) + sparkFunctions.col(" df2.b" ))
142+ )
143+ .count()
144+
145+ val typed = ds
146+ .joinInner(ds)(
147+ (ds.colLeft(' a ) + ds.colLeft(' b )) === (ds.colRight(' a ) + ds
148+ .colRight(' b ))
149+ )
150+ .count()
151+ .run()
95152
96153 vanilla ?= typed
97154 }
98155
99156 check(prop[Int , Int ] _)
100157 }
101158
102- test(" Do you want ambiguous self join? This is how you get ambiguous self join." ) {
159+ test(
160+ " Do you want ambiguous self join? This is how you get ambiguous self join."
161+ ) {
103162 def prop [
104- A : TypedEncoder : CatalystNumeric : Ordering ,
105- B : TypedEncoder : Ordering
106- ](data : List [X3 [A , A , B ]]): Prop =
107- allowTrivialJoin { allowAmbiguousJoin {
108- val ds = TypedDataset .create(data)
109-
110- // The point I'm making here is that it "behaves just like Spark". I
111- // don't know (or really care about how) how Spark disambiguates that
112- // internally...
113- val vanilla = ds.dataset.join(ds.dataset,
114- (ds.dataset(" a" ) + ds.dataset(" b" )) ===
115- (ds.dataset(" a" ) + ds.dataset(" b" ))).count()
116-
117- val typed = ds.joinInner(ds)(
118- (ds.col(' a ) + ds.col(' b )) === (ds.col(' a ) + ds.col(' b ))
119- ).count().run()
120-
121- vanilla ?= typed
122- } }
123-
124- check(prop[Int , Int ] _)
125- }
163+ A : TypedEncoder : CatalystNumeric : Ordering ,
164+ B : TypedEncoder : Ordering
165+ ](data : List [X3 [A , A , B ]]
166+ ): Prop =
167+ allowTrivialJoin {
168+ allowAmbiguousJoin {
169+ val ds = TypedDataset .create(data)
170+
171+ // The point I'm making here is that it "behaves just like Spark". I
172+ // don't know (or really care) how Spark disambiguates that
173+ // internally...
174+ val vanilla = ds.dataset
175+ .join(
176+ ds.dataset,
177+ (ds.dataset(" a" ) + ds.dataset(" b" )) ===
178+ (ds.dataset(" a" ) + ds.dataset(" b" ))
179+ )
180+ .count()
181+
182+ val typed = ds
183+ .joinInner(ds)(
184+ (ds.col(' a ) + ds.col(' b )) === (ds.col(' a ) + ds.col(' b ))
185+ )
186+ .count()
187+ .run()
188+
189+ vanilla ?= typed
190+ }
191+ }
192+
193+ check(prop[Int , Int ] _)
194+ }
126195
127196 test(" colLeft and colRight are equivalent to col outside of joins" ) {
128- def prop [A , B , C , D ](data : Vector [X4 [A , B , C , D ]])(
129- implicit
130- ea : TypedEncoder [A ],
131- ex4 : TypedEncoder [X4 [A , B , C , D ]]
132- ): Prop = {
197+ def prop [A , B , C , D ](
198+ data : Vector [X4 [A , B , C , D ]]
199+ )(implicit
200+ ea : TypedEncoder [A ],
201+ ex4 : TypedEncoder [X4 [A , B , C , D ]]
202+ ): Prop = {
133203 val dataset = TypedDataset .create(data)
134- val selectedCol = dataset.select(dataset.col [A ](' a )).collect().run().toVector
135- val selectedColLeft = dataset.select(dataset.colLeft [A ](' a )).collect().run().toVector
136- val selectedColRight = dataset.select(dataset.colRight[A ](' a )).collect().run().toVector
204+ val selectedCol =
205+ dataset.select(dataset.col[A ](' a )).collect().run().toVector
206+ val selectedColLeft =
207+ dataset.select(dataset.colLeft[A ](' a )).collect().run().toVector
208+ val selectedColRight =
209+ dataset.select(dataset.colRight[A ](' a )).collect().run().toVector
137210
138211 (selectedCol ?= selectedColLeft) && (selectedCol ?= selectedColRight)
139212 }
@@ -145,16 +218,26 @@ class SelfJoinTests extends TypedDatasetSuite {
145218 }
146219
147220 test(" colLeft and colRight are equivalent to col outside of joins - via files (codegen)" ) {
148- def prop [A , B , C , D ](data : Vector [X4 [A , B , C , D ]])(
149- implicit
150- ea : TypedEncoder [A ],
151- ex4 : TypedEncoder [X4 [A , B , C , D ]]
152- ): Prop = {
153- TypedDataset .create(data).write.mode(" overwrite" ).parquet(" ./target/testData" )
154- val dataset = TypedDataset .createUnsafe[X4 [A , B , C , D ]](session.read.parquet(" ./target/testData" ))
155- val selectedCol = dataset.select(dataset.col [A ](' a )).collect().run().toVector
156- val selectedColLeft = dataset.select(dataset.colLeft [A ](' a )).collect().run().toVector
157- val selectedColRight = dataset.select(dataset.colRight[A ](' a )).collect().run().toVector
221+ def prop [A , B , C , D ](
222+ data : Vector [X4 [A , B , C , D ]]
223+ )(implicit
224+ ea : TypedEncoder [A ],
225+ ex4 : TypedEncoder [X4 [A , B , C , D ]]
226+ ): Prop = {
227+ TypedDataset
228+ .create(data)
229+ .write
230+ .mode(" overwrite" )
231+ .parquet(" ./target/testData" )
232+ val dataset = TypedDataset .createUnsafe[X4 [A , B , C , D ]](
233+ session.read.parquet(" ./target/testData" )
234+ )
235+ val selectedCol =
236+ dataset.select(dataset.col[A ](' a )).collect().run().toVector
237+ val selectedColLeft =
238+ dataset.select(dataset.colLeft[A ](' a )).collect().run().toVector
239+ val selectedColRight =
240+ dataset.select(dataset.colRight[A ](' a )).collect().run().toVector
158241
159242 (selectedCol ?= selectedColLeft) && (selectedCol ?= selectedColRight)
160243 }
0 commit comments