使用 UDAF 的一般步骤:
自定义类继承UserDefinedAggregateFunction或者Aggregator,对每个阶段方法做实现在sparkSession中注册UDAF,为其绑定一个名字在sql语句中使用上面绑定的名字调用 继承UserDefinedAggregateFunction class UDAFStringCount extends UserDefinedAggregateFunction { // 输入数据的类型 override def inputSchema: StructType = { StructType(Array(StructField("str", StringType, true))) } // 中间聚合时所处理的数据 override def bufferSchema: StructType = { StructType(Array(StructField("count", IntegerType, true))) } // 函数返回的类型 override def dataType: DataType = { IntegerType } // 指定是否是确定性的 override def deterministic: Boolean = { true } // 为每个分组的数据执行初始化操作 override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0 } // 每个分组有新值过来,如何进行分组对应的聚合值的计算 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { buffer(0) = buffer.getAs[Int](0) + 1 } // 合并,一个分组的数据会分布在多个节点上处理,所以最后要用merge override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0) } // 通过中间的缓存聚合值,最后返回一个最终的聚合值 override def evaluate(buffer: Row): Any = { buffer.getAs[Int](0) } } 注册udaf,sql调用 object UDAF { def main(args: Array[String]): Unit = { //UDAF可以针对多行输入,进行聚合计算,返回一个输出 val sparkSession = SparkSession.builder() .appName("UDAF") .master("local[2]") .getOrCreate() val names = Array("zhangsan", "lisi", "wangwu", "Tom", "Jerry", "zhangsan", "Tom", "zhangsan", "lisi", "wangwu", "Tom", "Jerry", "Alan") val namesRDD = sparkSession.sparkContext.parallelize(names) import sparkSession.implicits._ val namesDF = namesRDD.toDF("name") namesDF.createOrReplaceTempView("udafTest") sparkSession.udf.register("strCount", new UDAFStringCount) sparkSession.sql("select name, strCount(name) len from udafTest group by name").show() } } +--------+---+ | name|len| +--------+---+ | wangwu| 2| | Tom| 3| | Jerry| 2| |zhangsan| 3| | Alan| 1| | lisi| 2| +--------+---+ 继承Aggregator 在这里插入代码片