Spark SQL UDF and UDAF


UDF

A UDF (User Defined Function) is what we write when Spark's built-in functions cannot meet a requirement. Example:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object UDF {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .appName("UDF")
      .master("local[2]")
      .getOrCreate()

    // Create an RDD
    val names = Array("zhangsan", "lisi", "wangwu", "Tom", "Jerry", "Alan")
    val namesRDD = sparkSession.sparkContext.parallelize(names)

    // Convert to a DataFrame
    // Option 1: build the schema programmatically
    // (commented out here: only one of the two options may define namesDF)
    // val namesRowRDD = namesRDD.map(x => Row(x))
    // val schema = StructType(Array(
    //   StructField("name", StringType, true)
    // ))
    // val namesDF = sparkSession.createDataFrame(namesRowRDD, schema)

    // Option 2: by reflection, via the implicits
    import sparkSession.implicits._
    val namesDF = namesRDD.toDF("name")

    namesDF.createOrReplaceTempView("udfTest")

    // Register the UDF
    sparkSession.udf.register("strLength", (str: String) => str.length)
    sparkSession.sql("select name, strLength(name) length from udfTest").show()
  }
}
```

Output:

```
+--------+------+
|    name|length|
+--------+------+
|zhangsan|     8|
|    lisi|     4|
|  wangwu|     6|
|     Tom|     3|
|   Jerry|     5|
|    Alan|     4|
+--------+------+
```
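The same function does not have to go through SQL registration at all: `org.apache.spark.sql.functions.udf` wraps an ordinary Scala function into one usable directly in the DataFrame API. A minimal sketch, reusing the `namesDF` from the example above (the val name `strLengthUdf` is just for illustration):

```scala
import org.apache.spark.sql.functions.{col, udf}

// Wrap the same lambda as a DataFrame-API UDF; no temp view or SQL string needed
val strLengthUdf = udf((str: String) => str.length)

// Equivalent to the SQL query above
namesDF.select(col("name"), strLengthUdf(col("name")).as("length")).show()
```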

UDAF

A UDAF (User Defined Aggregate Function) is a user-defined aggregate function. The difference between an aggregate function and an ordinary function: an ordinary function takes one row of input and produces one output, whereas an aggregate function takes a group of inputs (usually many rows) and produces a single output, i.e. it combines the group's values according to a specified rule. A UDAF can be implemented in two ways: 1. extend UserDefinedAggregateFunction; 2. extend Aggregator (whose advantage is that it is strongly typed).

The general steps for using a UDAF:

1. Define a class that extends UserDefinedAggregateFunction or Aggregator, and implement each lifecycle method.
2. Register the UDAF with the SparkSession, binding it to a name.
3. Call it in SQL by the bound name.

Extending UserDefinedAggregateFunction

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class UDAFStringCount extends UserDefinedAggregateFunction {
  // Type of the input data
  override def inputSchema: StructType = {
    StructType(Array(StructField("str", StringType, true)))
  }

  // Type of the intermediate aggregation buffer
  override def bufferSchema: StructType = {
    StructType(Array(StructField("count", IntegerType, true)))
  }

  // Return type of the function
  override def dataType: DataType = {
    IntegerType
  }

  // Whether the function is deterministic
  override def deterministic: Boolean = {
    true
  }

  // Initialize the aggregation buffer for each group
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  // How the group's aggregate value is updated when a new value arrives
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  // A group's data may be processed on several nodes, so partial buffers must be merged
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  // Produce the final aggregate value from the intermediate buffer
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
}
```

Register the UDAF and call it from SQL:

```scala
import org.apache.spark.sql.SparkSession

object UDAF {
  def main(args: Array[String]): Unit = {
    // A UDAF aggregates multiple input rows into one output
    val sparkSession = SparkSession.builder()
      .appName("UDAF")
      .master("local[2]")
      .getOrCreate()

    val names = Array("zhangsan", "lisi", "wangwu", "Tom", "Jerry", "zhangsan",
      "Tom", "zhangsan", "lisi", "wangwu", "Tom", "Jerry", "Alan")
    val namesRDD = sparkSession.sparkContext.parallelize(names)

    import sparkSession.implicits._
    val namesDF = namesRDD.toDF("name")
    namesDF.createOrReplaceTempView("udafTest")

    sparkSession.udf.register("strCount", new UDAFStringCount)
    sparkSession.sql("select name, strCount(name) len from udafTest group by name").show()
  }
}
```

Output:

```
+--------+---+
|    name|len|
+--------+---+
|  wangwu|  2|
|     Tom|  3|
|   Jerry|  2|
|zhangsan|  3|
|    Alan|  1|
|    lisi|  2|
+--------+---+
```

Extending Aggregator
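A minimal sketch of the same string-count aggregation written against the typed Aggregator API (the names StringCountAggregator and strCount2 are assumptions for illustration, not from the original). Since Spark 3.0, UserDefinedAggregateFunction is deprecated and Aggregator is the recommended replacement:

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Typed Aggregator: IN = String (one input value per row),
// BUF = Int (running count), OUT = Int (final count)
object StringCountAggregator extends Aggregator[String, Int, Int] {
  // Initial (zero) value of the buffer
  override def zero: Int = 0

  // Fold one input value into the running count
  override def reduce(buffer: Int, input: String): Int = buffer + 1

  // Merge partial counts computed on different nodes
  override def merge(b1: Int, b2: Int): Int = b1 + b2

  // Turn the final buffer into the output value
  override def finish(reduction: Int): Int = reduction

  // Encoders for the buffer and output types
  override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
```

On Spark 3.0+ it can be registered for SQL much like the untyped version, e.g. `sparkSession.udf.register("strCount2", org.apache.spark.sql.functions.udaf(StringCountAggregator))`. On Spark 2.x the typed route is to apply it to a Dataset, e.g. `namesDF.as[String].groupByKey(identity).agg(StringCountAggregator.toColumn).show()`.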