Spark join
在Spark
生态里,join
分为Spark SQL
的join
和基于dataframe
的join
,这次我们来谈谈最常用的基于dataframe
的join
方法详解示例。
Spark支持所有类型的Join[1],包括:
- inner join
- left outer join
- right outer join
- full outer join
- left semi join
- left anti join
- cross join
下面分别阐述这几种Join的例子。
数据准备
运行环境
首先创建一下运行的类,本文将使用scalatest
测试用例来做。[2]
|
基本变量
然后加载成员变量spark
和sc
|
日志级别
如果觉得有需要的话,可以调高spark
的默认日志级别:
|
创建数据
- 创建第一张表的数据
cid
代表顾客编号(customID)
,name
代表顾客的姓名
。[2:1]
1号harbor
2号吴老师(mr.wu)
3号八宝粥(babaozhou)
|
-
第二张表
oid
代表订单的ID
,cid
代表顾客编号(customID)
,amount
代表此订单的金额
。[2:2]
|
这两张表的数据:
- 顾客表中,顾客3:babaozhou,没有产生订单;
- 订单表中,订单5:顾客id1000,在顾客表中没有此人。
示例
inner join
inner join 是我们一般比较熟悉也经常用的一种方式。
inner join 相当于是要把两个集合中重叠的部分作为结果,要求是此id在左侧出现了,同时也在右侧出现了。
|
cross join
cross join 会生成笛卡尔积
,也就是说,左边的每一项和右边的每一项都一一连接,简单来说就是:我们的顾客表有3个人,订单表有5个订单,那么应用 cross join 最终就会有15条数据。
|
outer join
在spark
里面,outer
, full,
fullouter
, full_outer
都代表的是outer
join。
outer join 就是两个集合的全集,包括两个集合key
对不上的null
数据。简单来说就是:相比内连接多了null
数据。
|
left join
在spark
里面,left
, leftouter
, left_outer
都代表的是left
join。
left outer join 相比内连接多了左边key
不为null
的null
数据。
下面代码是 customer join order 的例子:
有一个在订单表
中找不到对应订单的人也出现了。
|
下面代码是 order join customer 的例子:
有一个在顾客表
中找不到对应人的订单也出现了。
|
right join
rightouter
, right
, right_outer
都指的是同一种join
方式。和left
join 很像,right
join 只是顺序颠倒过来,很好理解。
下面代码是 customer join order 的例子:
有一个在顾客表
中找不到对应人的订单也出现了。
|
下面代码是 order join customer 的例子:
有一个在订单表
中找不到对应订单的人也出现了。
|
lefi semi join
左半连接,又可以叫leftsemi
, left_semi
。正如其名字暗示的一样,他只会按照条件和顺序去 join
,然后凭 join
结果只去左边表的列,并不会把两张表合并到一起,相当于右边的表只是作为判定使用,并不会掺和进结果中。
下面代码是 order join customer 的例子:
最终结果不会出现两边的空数据,顾客ID
为1000的那个订单5
就不会出现在最终的结果里。
|
下面代码是 customer join order 的例子:
最终结果不会出现两边的空数据,八宝粥(babaozhou)同学没有产生过订单,那么就不会出现在最终的结果里。
|
left anti join
leftanti
和left_anti
是同样的意思,相当于两个集合相减。
下面代码是 customer join order 的例子:
最终结果只出现两边的空数据,八宝粥(babaozhou)同学没有产生过订单,那么就会出现在最终的结果里。也就是说:
|
最终的结果只剩下八宝粥(babaozhou)同学。
|
下面代码是 orders join customers 的例子:
最终结果只出现两边的空数据,顾客ID
为1000的那个订单5
就会出现在最终的结果里。也就是说:
|
最终的结果只剩下顾客ID
为1000的那个订单5
。
|
总结
Spark
里面join
方法,种类丰富,能满足日常遇到的几乎一切问题。
但是由于其是shuffle
操作,是巨大的计算开销,所以在使用的时候也需要关注性能优化,比如:
- 设置参数
spark.sql.autoBroadcastJoinThreshold
的值(默认10M)[1:1],让小表被广播到每个节点,以降低大表的 I/O 开销,提升性能。 spark
默认sort merge join
,可以看情况创造条件使用hash join
提升性能。
完整的代码可以再这里找到:https://gist.github.com/c0f8a928dfe75aaf85cf9aab3708593a
import org.apache.log4j.{Level, Logger} | |
import org.apache.spark.sql.{Row, SparkSession} | |
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType} | |
import org.scalatest.{FlatSpec, Matchers} | |
class JoinSpec extends FlatSpec with Matchers { | |
@transient lazy val logger: Logger = Logger.getLogger(getClass) | |
Logger.getRootLogger.setLevel(Level.INFO) | |
Logger.getLogger("org").setLevel(Level.WARN) | |
Logger.getLogger("akka").setLevel(Level.WARN) | |
lazy val spark: SparkSession = { | |
SparkSession | |
.builder() | |
.master("local") | |
.appName("spark test of join") | |
.getOrCreate() | |
} | |
val sc = spark.sparkContext | |
private val customers = spark.createDataFrame( | |
sc.parallelize(Seq( | |
Row(1, "harbor"), | |
Row(2, "mr.wu"), | |
Row(3, "babaozhou") | |
)), | |
schema = StructType( | |
Array[StructField]( | |
StructField("cid", IntegerType), | |
StructField("name", StringType) | |
)) | |
) | |
private val orders = spark.createDataFrame( | |
sc.parallelize(Seq( | |
Row(1, 1, 50.0d), | |
Row(2, 2, 10d), | |
Row(3, 2, 10d), | |
Row(4, 2, 10d), | |
Row(5, 1000, 19d) | |
)), schema = StructType( | |
Array[StructField]( | |
StructField("oid", IntegerType), | |
StructField("cid", IntegerType), | |
StructField("amount", DoubleType) | |
)) | |
) | |
"SparkJoin inner" should "collect only matching rows from both sides" in { | |
val innerJoinResultDF = orders.join(customers, Seq("cid"), joinType = "inner") | |
innerJoinResultDF.show() | |
val innerJoinResult = innerJoinResultDF.collect() | |
innerJoinResult should (have size 4 and contain allOf( | |
// |cid|oid|amount| name| | |
Row(1, 1, 50.0d, "harbor"), | |
Row(2, 2, 10d, "mr.wu"), | |
Row(2, 3, 10d, "mr.wu"), | |
Row(2, 4, 10d, "mr.wu") | |
)) | |
} | |
"SparkJoin cross" should "create a Cartesian Product" in { | |
val crossJoinResultDF = orders.crossJoin(customers) | |
crossJoinResultDF.show() | |
val crossJoinResult = crossJoinResultDF.collect() | |
crossJoinResult should (have size 15 and contain allOf( | |
// |oid| cid| amount| cid| name| | |
Row(1, 1, 50.0, 1, "harbor"), | |
Row(1, 1, 50.0, 2, "mr.wu"), | |
Row(1, 1, 50.0, 3, "babaozhou"), | |
Row(2, 2, 10.0, 1, "harbor"), | |
Row(2, 2, 10.0, 2, "mr.wu"), | |
Row(2, 2, 10.0, 3, "babaozhou"), | |
Row(3, 2, 10.0, 1, "harbor"), | |
Row(3, 2, 10.0, 2, "mr.wu"), | |
Row(3, 2, 10.0, 3, "babaozhou"), | |
Row(4, 2, 10.0, 1, "harbor"), | |
Row(4, 2, 10.0, 2, "mr.wu"), | |
Row(4, 2, 10.0, 3, "babaozhou"), | |
Row(5, 1000, 19.0, 1, "harbor"), | |
Row(5, 1000, 19.0, 2, "mr.wu"), | |
Row(5, 1000, 19.0, 3, "babaozhou") | |
)) | |
} | |
"SparkJoin outer" should "full outer join 相比内连接多了null数据" in { | |
// "outer", "full", "fullouter", "full_outer" | |
val outerJoinResultDF = orders.join(customers, Seq("cid"), joinType = "outer") | |
outerJoinResultDF.show() | |
val outerJoinResult = outerJoinResultDF.collect() | |
outerJoinResult should (have size 6 and contain allOf( | |
// | cid| oid| amount| name| | |
Row( 1, 1, 50.0d, "harbor"), | |
Row( 2, 2, 10d, "mr.wu"), | |
Row( 2, 3, 10d, "mr.wu"), | |
Row( 2, 4, 10d, "mr.wu"), | |
Row( 3, null, null, "babaozhou"), | |
Row( 1000, 5, 19d, null) | |
)) | |
} | |
"SparkJoin left order join customer" should "left outer join 相比内连接多了左边key不为null的null数据" in { | |
// "leftouter", "left", "left_outer" | |
val leftJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left") | |
leftJoinResultDF.show() | |
val leftJoinResult = leftJoinResultDF.collect() | |
leftJoinResult should (have size 5 and contain allOf( | |
// | cid| oid| amount| name| | |
Row( 1, 1, 50.0d, "harbor"), | |
Row( 2, 2, 10d, "mr.wu"), | |
Row( 2, 3, 10d, "mr.wu"), | |
Row( 2, 4, 10d, "mr.wu"), | |
Row( 1000, 5, 19d, null) | |
)) | |
} | |
"SparkJoin left customer join order" should "left outer join 相比内连接多了左边key不为null的null数据" in { | |
// "leftouter", "left", "left_outer" | |
val leftJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left") | |
leftJoinResultDF.show() | |
val leftJoinResult = leftJoinResultDF.collect() | |
leftJoinResult should (have size 5 and contain allOf( | |
// |cid| name| oid| amount| | |
Row(1, "harbor", 1, 50.0d), | |
Row(2, "mr.wu", 2, 10d), | |
Row(2, "mr.wu", 3, 10d), | |
Row(2, "mr.wu", 4, 10d), | |
Row(3, "babaozhou", null, null) | |
)) | |
} | |
"SparkJoin right customer join order" should "和left order join customer除了列数据顺序不一样之外其他都一样" in { | |
// "rightouter", "right", "right_outer" | |
val rightJoinResultDF = customers.join(orders, Seq("cid"), joinType = "right") | |
rightJoinResultDF.show() | |
val rightJoinResult = rightJoinResultDF.collect() | |
rightJoinResult should (have size 5 and contain allOf( | |
// | cid| name| oid| amount| | |
Row( 1, "harbor", 1, 50.0d), | |
Row( 2, "mr.wu", 2, 10d), | |
Row( 2, "mr.wu", 3, 10d), | |
Row( 2, "mr.wu", 4, 10d), | |
Row( 1000, null, 5, 19d) | |
)) | |
} | |
"SparkJoin right order join customer" should "和left customer join order除了列数据顺序不一样之外其他都一样" in { | |
// "rightouter", "right", "right_outer" | |
val rightJoinResultDF = orders.join(customers, Seq("cid"), joinType = "right") | |
rightJoinResultDF.show() | |
val rightJoinResult = rightJoinResultDF.collect() | |
rightJoinResult should (have size 5 and contain allOf( | |
// |cid| oid| amount| name| | |
Row(1, 1, 50.0d, "harbor"), | |
Row(2, 2, 10d, "mr.wu"), | |
Row(2, 3, 10d, "mr.wu"), | |
Row(2, 4, 10d, "mr.wu"), | |
Row(3, null, null, "babaozhou") | |
)) | |
} | |
"SparkJoin left_semi order join customer" should "和inner除了少了右边独有的列之外其他都一样" in { | |
// "leftsemi", "left_semi" | |
val leftSemiJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left_semi") | |
leftSemiJoinResultDF.show() | |
val leftSemiJoinResult = leftSemiJoinResultDF.collect() | |
leftSemiJoinResult should (have size 4 and contain allOf( | |
// |cid|oid|amount| | |
Row(1, 1, 50.0d), | |
Row(2, 2, 10d), | |
Row(2, 3, 10d), | |
Row(2, 4, 10d) | |
)) | |
} | |
"SparkJoin left_semi custom join order" should "和inner除了少了右边独有的列还有去重了之外其他都一样" in { | |
// "leftsemi", "left_semi" | |
val leftSemiJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left_semi") | |
leftSemiJoinResultDF.show() | |
val leftSemiJoinResult = leftSemiJoinResultDF.collect() | |
leftSemiJoinResult should (have size 2 and contain allOf( | |
// |cid |name | | |
Row(1, "harbor"), | |
Row(2, "mr.wu") | |
)) | |
} | |
"SparkJoin left_anti custom join order" should "和 left_semi custom join order 组合在一起就是完整的 customs 表" in { | |
// "leftanti", "left_anti" | |
val leftAntiJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left_anti") | |
leftAntiJoinResultDF.show() | |
val leftAntiJoinResult = leftAntiJoinResultDF.collect() | |
leftAntiJoinResult should (have size 1 and contain ( | |
// |cid|name| | |
Row(3, "babaozhou") | |
)) | |
} | |
"SparkJoin left_anti order join customer" should "和 left_semi order join customer 组合在一起就是完整的 orders 表" in { | |
// "leftanti", "left_anti" | |
val leftAntiJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left_anti") | |
leftAntiJoinResultDF.show() | |
val leftAntiJoinResult = leftAntiJoinResultDF.collect() | |
leftAntiJoinResult should (have size 1 and contain ( | |
// |cid|oid|amount| | |
Row(1000,5, 19d) | |
)) | |
} | |
} |
参考文献
Spark SQL 之 Join 实现[EB/OL]. 守护之鲨, [2019-11-04]. http://sharkdtu.com/posts/spark-sql-join.html. ↩︎ ↩︎
types in Spark SQL[EB/OL]. [2019-11-04]. https://www.waitingforcode.com/apache-spark-sql/join-types-spark-sql/read. ↩︎ ↩︎ ↩︎
Preview: