Spark SQL Custom Functions (UDF)



    1. Background

    SQL comes with built-in functions, but when the business logic is more complex and you want functions that are more flexible and reusable, you need to write your own: UDFs (user defined functions). Besides the plain UDF, which maps one input row to one output value, there are two related kinds: the UDTF (user defined table-generating function), which expands one input row into multiple output rows, and the UDAF (user defined aggregate function), which aggregates multiple input rows into a single value.
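
    Spark SQL has no separate registration API for Scala UDTFs; the table-generating case mentioned above is usually covered by registering a UDF that returns a Seq/array and combining it with the built-in explode generator. A minimal sketch, with a made-up view name v_lines and sample data:

        import org.apache.spark.sql.SparkSession

        object UdtfLikeExample {
          def main(args: Array[String]): Unit = {
            val spark: SparkSession = SparkSession.builder()
              .appName("UdtfLikeExample")
              .master("local")
              .getOrCreate()
            import spark.implicits._

            List("hello spark sql", "hello udf").toDF("line").createTempView("v_lines")

            // A UDF returning a Seq; explode then turns each element into its own row,
            // which is the UDTF-style "one row in, many rows out" behaviour
            spark.udf.register("my_split", (s: String) => s.split(" ").toSeq)

            spark.sql("select explode(my_split(line)) as word from v_lines").show()

            spark.stop()
          }
        }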

    2. Creating and Using UDFs

    The following examples show how to create and use custom functions.

    2.1 A custom function that resolves GPS coordinates to location information

    Environment:

    - an AMap (Gaode Maps) application key
    - IDEA 2020
    - Maven 3.6.3
    - Scala 2.12.12
    - Spark 3.0.1

    pom.xml:

        <!-- Some shared constants -->
        <properties>
            <maven.compiler.source>1.8</maven.compiler.source>
            <maven.compiler.target>1.8</maven.compiler.target>
            <scala.version>2.12.10</scala.version>
            <spark.version>3.0.1</spark.version>
            <hbase.version>2.2.5</hbase.version>
            <hadoop.version>3.2.1</hadoop.version>
            <encoding>UTF-8</encoding>
        </properties>

        <dependencies>
            <!-- Scala dependency -->
            <dependency>
                <groupId>org.scala-lang</groupId>
                <artifactId>scala-library</artifactId>
                <version>${scala.version}</version>
                <!-- "provided": on the compile classpath but not packaged into the jar -->
                <!-- <scope>provided</scope> -->
            </dependency>
            <!-- https://mvnrepository.com/artifact/org.apache.httpcomponents/httpclient -->
            <dependency>
                <groupId>org.apache.httpcomponents</groupId>
                <artifactId>httpclient</artifactId>
                <version>4.5.12</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-sql_2.12</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-core_2.12</artifactId>
                <version>${spark.version}</version>
                <!-- "provided": on the compile classpath but not packaged into the jar -->
                <!-- <scope>provided</scope> -->
            </dependency>
            <!-- https://mvnrepository.com/artifact/com.alibaba/fastjson -->
            <dependency>
                <groupId>com.alibaba</groupId>
                <artifactId>fastjson</artifactId>
                <version>1.2.73</version>
            </dependency>
            <!-- https://mvnrepository.com/artifact/mysql/mysql-connector-java -->
            <dependency>
                <groupId>mysql</groupId>
                <artifactId>mysql-connector-java</artifactId>
                <version>5.1.47</version>
            </dependency>
        </dependencies>

        <build>
            <pluginManagement>
                <plugins>
                    <!-- Plugin that compiles Scala -->
                    <plugin>
                        <groupId>net.alchim31.maven</groupId>
                        <artifactId>scala-maven-plugin</artifactId>
                        <version>3.2.2</version>
                    </plugin>
                    <!-- Plugin that compiles Java -->
                    <plugin>
                        <groupId>org.apache.maven.plugins</groupId>
                        <artifactId>maven-compiler-plugin</artifactId>
                        <version>3.5.1</version>
                    </plugin>
                </plugins>
            </pluginManagement>
            <plugins>
                <plugin>
                    <groupId>net.alchim31.maven</groupId>
                    <artifactId>scala-maven-plugin</artifactId>
                    <executions>
                        <execution>
                            <id>scala-compile-first</id>
                            <phase>process-resources</phase>
                            <goals>
                                <goal>add-source</goal>
                                <goal>compile</goal>
                            </goals>
                        </execution>
                        <execution>
                            <id>scala-test-compile</id>
                            <phase>process-test-resources</phase>
                            <goals>
                                <goal>testCompile</goal>
                            </goals>
                        </execution>
                    </executions>
                </plugin>
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-compiler-plugin</artifactId>
                    <executions>
                        <execution>
                            <phase>compile</phase>
                            <goals>
                                <goal>compile</goal>
                            </goals>
                        </execution>
                    </executions>
                </plugin>
                <!-- Plugin that builds the shaded jar -->
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-shade-plugin</artifactId>
                    <version>2.4.3</version>
                    <executions>
                        <execution>
                            <phase>package</phase>
                            <goals>
                                <goal>shade</goal>
                            </goals>
                            <configuration>
                                <filters>
                                    <filter>
                                        <artifact>*:*</artifact>
                                        <excludes>
                                            <exclude>META-INF/*.SF</exclude>
                                            <exclude>META-INF/*.DSA</exclude>
                                            <exclude>META-INF/*.RSA</exclude>
                                        </excludes>
                                    </filter>
                                </filters>
                            </configuration>
                        </execution>
                    </executions>
                </plugin>
            </plugins>
        </build>

    The function that calls AMap's reverse-geocoding API, and the test program:

        import com.alibaba.fastjson.{JSON, JSONObject}
        import org.apache.http.HttpEntity
        import org.apache.http.client.methods.{CloseableHttpResponse, HttpGet}
        import org.apache.http.impl.client.{CloseableHttpClient, HttpClients}
        import org.apache.http.util.EntityUtils

        object GeoFunc {

          // Resolve a longitude/latitude pair to province and city
          val geo = (longitude: Double, latitude: Double) => {
            val httpClient: CloseableHttpClient = HttpClients.createDefault()
            // Build the request
            val httpGet = new HttpGet(s"https://restapi.amap.com/v3/geocode/regeo?&location=$longitude,$latitude&key=71cc7d9df22483b27ec40ecb45d9d87b")
            // Send the request and read the response
            val response: CloseableHttpResponse = httpClient.execute(httpGet)
            var province: String = null
            var city: String = null
            try {
              // Extract the entity from the response
              val entity: HttpEntity = response.getEntity
              if (response.getStatusLine.getStatusCode == 200) {
                // Read the body as a string
                val resultStr: String = EntityUtils.toString(entity)
                // Parse the returned JSON
                val jSONObject: JSONObject = JSON.parseObject(resultStr)
                // Navigate the fields of AMap's reverse-geocoding response
                val regeocode: JSONObject = jSONObject.getJSONObject("regeocode")
                if (regeocode != null && !regeocode.isEmpty) {
                  val address: JSONObject = regeocode.getJSONObject("addressComponent")
                  province = address.getString("province")
                  city = address.getString("city")
                }
              }
            } catch {
              case e: Exception => {} // errors simply leave province/city as null
            } finally {
              // Close the connection after every request
              response.close()
              httpClient.close()
            }
            (province, city)
          }
        }

        import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

        object UDFTest1 {
          def main(args: Array[String]): Unit = {
            val sparkSession: SparkSession = SparkSession.builder()
              .appName("UDFTest1")
              .master("local")
              .getOrCreate()

            import sparkSession.implicits._

            // sample location strings, e.g. "118.396128,35.916527"
            val dataset: Dataset[(String, String)] = sparkSession.createDataset(
              List(("a", "118.396128,35.916527"), ("b", "118.596128,35.976527")))
            val dataFrame: DataFrame = dataset.toDF("uid", "location")
            dataFrame.createTempView("v_location")

            // Register the custom function
            sparkSession.udf.register("geo", GeoFunc.geo)

            dataFrame.show()

            val dataFrame1: DataFrame = sparkSession.sql(
              """
                |select
                |uid,
                |geo(loc1, loc2) as province_city
                |from
                |(
                |  select
                |  uid,
                |  cast(loc_pair[0] as double) as loc1,
                |  cast(loc_pair[1] as double) as loc2
                |  from
                |  (
                |    select
                |    uid,
                |    split(location, '[,]') as loc_pair
                |    from
                |    v_location
                |  )
                |)
                |""".stripMargin)
            dataFrame1.show()

            sparkSession.stop()
          }
        }

    The query can be developed step by step:

        -- split the location string first
        select uid, split(location, '[,]') as loc_pair from v_location

        -- cast the two parts to double
        select
          uid,
          cast(loc_pair[0] as double) as loc1,
          cast(loc_pair[1] as double) as loc2
        from
        (
          select uid, split(location, '[,]') as loc_pair from v_location
        )

        -- call the custom function
        select
          uid,
          geo(loc1, loc2) as province_city
        from
        (
          select
            uid,
            cast(loc_pair[0] as double) as loc1,
            cast(loc_pair[1] as double) as loc2
          from
          (
            select uid, split(location, '[,]') as loc_pair from v_location
          )
        )
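
    The registered UDF is called through SQL above; the same function can also be applied with the DataFrame API. A minimal sketch, assuming the dataFrame and GeoFunc from the example above and that import sparkSession.implicits._ is in scope:

        import org.apache.spark.sql.functions.{split, udf}

        // Wrap the plain Scala function as a Column-level UDF
        val geoUdf = udf(GeoFunc.geo)

        dataFrame
          .withColumn("loc_pair", split($"location", ","))
          .withColumn("province_city",
            geoUdf($"loc_pair".getItem(0).cast("double"), $"loc_pair".getItem(1).cast("double")))
          .select("uid", "province_city")
          .show()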

    2.2 A custom string-concatenation function

        import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

        object UDF_CustomConcat {
          def main(args: Array[String]): Unit = {
            val sparkSession: SparkSession = SparkSession.builder()
              .appName("UDF_CustomConcat")
              .master("local")
              .getOrCreate()

            import sparkSession.implicits._

            // Build a Dataset and convert it to a DataFrame
            val dataset: Dataset[(String, String)] = sparkSession.createDataset(
              List(("湖南", "长沙"), ("江西", "南昌"), ("湖北", "武汉")))
            val dataFrame: DataFrame = dataset.toDF("province", "city")

            // The custom function; in real code pick a descriptive name that says what it does
            val udf_func = (arg1: String, arg2: String) => {
              arg1 + "-" + arg2
            }

            // Register the function; the registration is temporary and only visible to this SparkSession
            sparkSession.udf.register("udf_func", udf_func)

            // Register a view before running SQL
            dataFrame.createTempView("v_test")

            val dataFrame1: DataFrame = sparkSession.sql(
              "select udf_func(province, city) as concat_result from v_test")
            dataFrame1.show()

            sparkSession.close()
          }
        }
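
    The same concatenation can also be done without SQL by wrapping the function with org.apache.spark.sql.functions.udf. A short sketch, assuming the dataFrame from the example above and import sparkSession.implicits._ in scope; both versions should print 湖南-长沙, 江西-南昌 and 湖北-武汉:

        import org.apache.spark.sql.functions.udf

        // Column-level version of the same concatenation function
        val concatUdf = udf((arg1: String, arg2: String) => arg1 + "-" + arg2)

        dataFrame
          .select(concatUdf($"province", $"city").as("concat_result"))
          .show()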

    2.3 Converting IP addresses to province (city) location information

    Environment. First, an IP dictionary file (it is large, so only part of it is shown; data sets like this can be bought on marketplaces such as Taobao, Pinduoduo or Xianyu):

        1.4.8.0|1.4.127.255|17041408|17072127|亚洲|中国|广东|广州||电信|440100|China|CN|113.280637|23.125178
        1.8.0.0|1.8.255.255|17301504|17367039|亚洲|中国|北京|北京|海淀|北龙中网|110108|China|CN|116.29812|39.95931
        1.10.0.0|1.10.7.255|17432576|17434623|亚洲|中国|广东|广州||电信|440100|China|CN|113.280637|23.125178
        1.10.8.0|1.10.9.255|17434624|17435135|亚洲|中国|福建|福州||电信|350100|China|CN|119.306239|26.075302
        1.10.11.0|1.10.15.255|17435392|17436671|亚洲|中国|福建|福州||电信|350100|China|CN|119.306239|26.075302
        1.10.16.0|1.10.127.255|17436672|17465343|亚洲|中国|广东|广州||电信|440100|China|CN|113.280637|23.125178
        1.12.0.0|1.12.255.255|17563648|17629183|亚洲|中国|北京|北京||方正宽带|110100|China|CN|116.405285|39.904989
        1.13.0.0|1.13.71.255|17629184|17647615|亚洲|中国|吉林|长春||方正宽带|220100|China|CN|125.3245|43.886841
        1.13.72.0|1.13.87.255|17647616|17651711|亚洲|中国|吉林|吉林||方正宽带|220200|China|CN|126.55302|43.843577
        1.13.88.0|1.13.95.255|17651712|17653759|亚洲|中国|吉林|长春||方正宽带|220100|China|CN|125.3245|43.886841
        1.13.96.0|1.13.127.255|17653760|17661951|亚洲|中国|天津|天津||方正宽带|120100|China|CN|117.190182|39.125596
        1.13.128.0|1.13.191.255|17661952|17678335|亚洲|中国|吉林|长春||方正宽带|220100|China|CN|125.3245|43.886841
        1.13.192.0|1.14.95.255|17678336|17719295|亚洲|中国|辽宁|大连||方正宽带|210200|China|CN|121.618622|38.91459
        1.14.96.0|1.14.127.255|17719296|17727487|亚洲|中国|辽宁|鞍山||方正宽带|210300|China|CN|122.995632|41.110626
        1.14.128.0|1.14.191.255|17727488|17743871|亚洲|中国|上海|上海||方正宽带|310100|China|CN|121.472644|31.231706
        1.14.192.0|1.14.223.255|17743872|17752063|亚洲|中国|吉林|长春||方正宽带|220100|China|CN|125.3245|43.886841

    Second, the access log to be processed (one record per line, fields separated by "|"):

        20090121000732398422000|122.73.114.24|aa.991kk.com|/html/taotuchaoshi/2009/0120/7553.html|Mozilla/5.0 (Windows; U; Windows NT 5.1; zh-CN; rv:1.9.0.1) Gecko/2008070208 Firefox/3.0.1|http://aa.991kk.com/html/taotuchaoshi/index.html|
        20090121000732420671000|115.120.14.96|image.baidu.com|/i?ct=503316480&z=0&tn=baiduimagedetail&word=%B6%AF%CE%EF%D4%B0+%B3%A4%BE%B1%C2%B9&in=32346&cl=2&cm=1&sc=0&lm=-1&pn=527&rn=1&di=2298496252&ln=615|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; GTB5; TencentTraveler 4.0)|http://image.baidu.com/i?tn=baiduimage&ct=201326592&cl=2&lm=-1&pv=&word=%B6%AF%CE%EF%D4%B0+%B3%A4%BE%B1%C2%B9&z=0&rn=21&pn=525&ln=615|BAIDUID=C1B0C0D4AA4A7D1BF9A0F74C4B727970:FG=1; BDSTAT=c3a929956cf1d97d5982b2b7d0a20cf431adcbef67094b36acaf2edda2cc5bc0; BDUSS=jBXVi1tQ3ZTSDJiflVHRERTSUNiYUtGRmNrWkZTYllWOEJZSk1-V0xFNU1lcDFKQkFBQUFBJCQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEztdUlM7XVJZ; BDSP=e114da18f3deb48fff2c9a8ef01f3a292df5e0fe2b24463340405da85edf8db1cb1349540923dd54564e9258d109b3de9c82d158ccbf6c81800a19d8bc3eb13533fa828ba61ea8d3fd1f4134970a304e251f95cad1c8a786c9177f3e6709c93d72cf5979; iCast_Rotator_1_1=1232467533578; iCast_Rotator_1_2=1232467564718
        20090121000732511280000|115.120.16.98|ui.ptlogin2.qq.com|/cgi-bin/login?appid=7000201&target=self&f_url=loginerroralert&s_url=http://minigame.qq.com/login/flashlogin/loginok.html|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; QQDownload 1.7; Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1) )|http://minigame.qq.com/act/rose0901/?aid=7000201&ADUIN=563586856&ADSESSION=1232467131&ADTAG=CLIENT.QQ.1833_SvrPush_Url.0|
        20090121000732967450000|117.101.219.112|list.taobao.com|/browse/50010404-302903/n-1----------------------0---------yes-------g,giydmmjxhizdsnjwgy5tgnbsgyzdumjshe4dmoa--g,giydmmjxhlk6xp63hmztimrwgi5mnvonvc764kbsfu2gg3jj--g,ojsxgzlsozsv64dsnfrwkwzvgawdcmbqlu-------------------40-grid-ratesum-0-all-302903.htm|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30)|http://list.taobao.com/browse/50010404-302903/n-1----------------------0---------yes-------g,giydmmjxhizdsnjwgy5tgnbsgyzdumjshe4dmoa--g,giydmmjxhlk6xp63hmztimrwgi5mnvonvc764kbsfu2gg3jj---------------------42-grid-ratesum-0-all-302903.htm|
        20090121000733245014000|117.101.227.3|se.360.cn|/web/navierr.htm?url=http://www.3320.net/blib/c/24/11839/&domain=www.3320.net&code=403&pid=sesoft&tabcount=7|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7; 360SE)||B=ID=435431224878393:V=2:S=8f59056144; __utma=148900148.1624674999336435000.1224880187.1226546993.1229991277.5; __utmz=148900148.1224880187.1.1.utmcsr=(direct)
        20090121000733290585000|117.101.206.175|wwww.17kk.net|/0OO000OO00O00OOOOO0/new/cjbbs/zx1.htm|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7)|http://wwww.17kk.net/0OO000OO00O00OOOOO0/new/cjbbs/zzx.htm|rtime=11; ltime=1232384469187; cnzz_eid=54742851-1228495798-http%3A//wwww.17kk.net/0OO000OO00O00OOOOO0/tongji1_l7kk.htm; cck_lasttime=1232381515031; cck_count=0; cnzz_a508803=8; vw508803=%3A80391793%3A; sin508803=none; ASPSESSIONIDQQAQQCRT=GKKKBIFCLAJPKGHGEKDEAPPB; ASPSESSIONIDQSCRQCQS=BCIBBIFCMLLPGEPBCFMEHGOA; ASPSESSIONIDSQBSRDRT=GPLKBIFCJBIAHLLBJLDDANGN; ASPSESSIONIDSQBRRDRS=AHLDAIFCDIINIGLMEEJJDGDN; __utma=152924281.4523785370259723000.1228495189.1232381092.1232466255.16; __utmb=152924281.8.10.1232466255; __utmz=152924281.1228495189.1.1.utmcsr=(direct)
        20090121000733387555000|117.101.206.175|wwww.17kk.net|/0OO000OO00O00OOOOO0/new/6cheng/nnts/180/sport.htm|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7)|http://wwww.17kk.net/0OO000OO00O00OOOOO0/new/6cheng/nnts/180/z8.htm|rtime=11; ltime=1232384469187; cnzz_eid=54742851-1228495798-http%3A//wwww.17kk.net/0OO000OO00O00OOOOO0/tongji1_l7kk.htm; cck_lasttime=1232381515031; cck_count=0; cnzz_a508803=8; vw508803=%3A80391793%3A; sin508803=none; ASPSESSIONIDQQAQQCRT=GKKKBIFCLAJPKGHGEKDEAPPB; ASPSESSIONIDQSCRQCQS=BCIBBIFCMLLPGEPBCFMEHGOA; ASPSESSIONIDSQBSRDRT=GPLKBIFCJBIAHLLBJLDDANGN; ASPSESSIONIDSQBRRDRS=AHLDAIFCDIINIGLMEEJJDGDN; __utma=152924281.4523785370259723000.1228495189.1232381092.1232466255.16; __utmb=152924281.8.10.1232466255; __utmz=152924281.1228495189.1.1.utmcsr=(direct)
        20090121000733393911000|115.120.10.168|my.51.com|/port/ajax/main.accesslog.php|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7)|http://my.51.com/|
        20090121000734192650000|115.120.9.235|www.baidu.com|/s?tn=mzmxzgx_pg&wd=xiao77%C2%DB%CC%B3|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7; Avant Browser; CIBA)|http://www.250cctv.cn/|BAIDUID=80DA16918ED68645445A6837338DBC5C:FG=1; BDSTAT=805379474b3ed4a4ab64034f78f0f736afc379310855b319ebc4b74541a9d141; BD_UTK_DVT=1; BDRCVFR[9o0so1JMIzY]=bTm-Pk1nd0D00; BDRCVFR[ZusMMNJpUDC]=QnHQ0TLSot3ILILQWcdnAPWIZm8mv3
        20090121000734299056000|125.213.97.6|haort.com|/Article/200901/2071_3.html|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; QQDownload 1.7; GTB5)|http://haort.com/Article/200901/2071_2.html|krviewcurc=1; krvedlaid=4285; kcc_169767kanrt=90; AJSTAT_ok_times=2; rtime=0; ltime=1220372703640; cnzz_eid=3485291-http%3A//rentiart.com/js/3.htm; krviewcurc=2; krvedlaid=3720; cck_lasttime=1232468301734; cck_count=0; AJSTAT_ok_pages=14; Cookie9=PopAnyibaSite; kcc_169767kanrt=39
        20090121000734469862000|117.101.213.66|book.tiexue.net|/Content_620501.html|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1)|http://book.tiexue.net/Content_619975.html|BookBackModel=5%2C1; RHistory=11225%2C%u7279%u6218%u5148%u9A71; __utma=247579266.194628209.1232339801.1232350177.1232464272.3; __utmb=247579266; __utmz=247579266.1232339801.1.1.utmccn=(direct)
        20090121000734529619000|115.120.0.192|www.cqpa.gov.cn|/u/cqpa/news_12757.shtml|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1)|http://www.cqpa.gov.cn/u/cqpa/|ASPSESSIONIDQSRAAAST=LGAIOKNCHPHMKALKIHPODCOB
        20090121000734819099000|117.101.225.140|jifen.qq.com|/static/mart/shelf/9.shtml|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; QQDownload 1.7; TencentTraveler 4.0; CIBA; .NET CLR 2.0.50727)|http://jifen.qq.com/mart/shelf_list.shtml?9|pvid=47052875; o_cookie=996957123; flv=9.0; pt2gguin=o0361474804; ptcz=1cc6f06a90bb8d1f53069184d85dd4d01dd8ca38eb7eb2fa615548538f133ede; r_cookie=912035936314; sc_cookie_floating_refresh241=3; icache=MMBMACEFG; uin_cookie=361474804; euin_cookie=AQAYuAH3EXdauOugz/OMzWPIssCyb0d3XzENGAAAAADefqSBU4unTT//nt3WNqaSQ2R44g==; pgv=ssid=s2273451828&SPATHTAG=CLIENT.PURSE.MyPurse.JifenInfo&KEYPATHTAG=2.1.1; verifysession=9b3f4c872a003e70cfe2ef5de1a62a3d862a448fd2f5b1b032512256fbd832fd7365b7d7619ef2ca; uin=o0361474804; skey=@BpkD0OWtL; JifenUserId=361474804; ACCUMULATE=g1qjCmEMXxtoOc1g00000681; _rsCS=1
        20090121000735126951000|115.120.4.164|www.5webgame.com|/bbs/2fly_gift.php|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; .NET CLR 1.1.4322)|http://www.5webgame.com/bbs/viewthread.php?tid=43&extra=page%3D1|
        20090121000735482286000|125.213.97.254|tieba.baidu.com|/f?kz=527788861|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; TencentTraveler )|http://tieba.baidu.com/f?ct=&tn=&rn=&pn=&lm=&sc=&kw=%D5%DB%CC%DA&rs2=0&myselectvalue=1&word=%D5%DB%CC%DA&tb=on|BAIDUID=D87E9C0E1E427AD5EEB37C6CC4B9C5CE:FG=1; BD_UTK_DVT=1; AdPlayed=true
        20090121000735619376000|115.120.3.253|m.163.com|/xxlwh/|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; InfoPath.1)|http://blog.163.com/xxlwh/home/|
        20090121000735819656000|115.120.13.149|2008.wg999.com|/|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7; TencentTraveler 4.0; Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1) )||ystat_bc_827474=23538508893341337937; ystat_bc_832488=29857733243775586653

    The code:

        import org.apache.spark.broadcast.Broadcast
        import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

        import scala.collection.mutable.ArrayBuffer

        object UDF_IP2Location {
          def main(args: Array[String]): Unit = {
            val sparkSession: SparkSession = SparkSession.builder()
              .appName("UDF_IP2Location")
              .master("local")
              .getOrCreate()

            import sparkSession.implicits._
            import org.apache.spark.sql.functions._

            // Read the IP rule dictionary as plain text
            val ipRules: Dataset[String] = sparkSession.read.textFile("E:\\DOITLearning\\12.Spark\\ip_location_dict.txt")

            // Parse, deduplicate and sort the rules, then collect them to the driver.
            // They are broadcast below so every executor holds a local copy,
            // which avoids a join (a map-side lookup instead).
            val ipRulesInDriver: Array[(Long, Long, String, String)] = ipRules.map(line => {
              val strings: Array[String] = line.split("[|]")
              val startIpNum: Long = strings(2).toLong
              val endIpNum: Long = strings(3).toLong
              val province: String = strings(6)
              val city: String = strings(7)
              (startIpNum, endIpNum, province, city)
            }).distinct()
              .sort($"_1".asc)
              .collect()

            // Register the broadcast variable
            val broadcastRefInDriver: Broadcast[Array[(Long, Long, String, String)]] =
              sparkSession.sparkContext.broadcast(ipRulesInDriver)

            // The custom function
            val ip2Location = (ip: String) => {
              val ipNumber: Long = IpUtils.ip2Long(ip)
              // The broadcast reference is captured by the closure
              val ipRulesInExecutor: Array[(Long, Long, String, String)] = broadcastRefInDriver.value
              // The dictionary file is already sorted, but sorting again above is a cheap safety net,
              // since binary search only works on sorted input
              val index: Int = IpUtils.binarySearch(ipRulesInExecutor, ipNumber)
              var province: String = null
              if (index >= 0) { // -1 means "not found"; index 0 is a valid hit
                province = ipRulesInExecutor(index)._3
              }
              province
            }

            // Register the function as a UDF
            sparkSession.udf.register("ip2Location", ip2Location)

            // Read the access log to be processed
            val logLines: Dataset[String] = sparkSession.read.textFile("E:\\DOITLearning\\12.Spark\\ipaccess.log")
            val dataFrame: DataFrame = logLines.map(line => {
              val strings: Array[String] = line.split("[|]")
              val ipStr: String = strings(1)
              ipStr
            }).toDF("ip")

            // Register a temporary view so the data can be queried with SQL
            dataFrame.createTempView("v_ips")

            // Run the query
            sparkSession.sql("select ip, ip2Location(ip) as location from v_ips")
              .limit(15)
              .show()

            sparkSession.close()
          }
        }

        // Utility object: converts an IP address to a Long and performs binary search
        object IpUtils {

          /**
           * Convert a dotted IP address string to its numeric (Long) form
           */
          def ip2Long(ip: String): Long = {
            val fragments = ip.split("[.]")
            var ipNum = 0L
            for (i <- 0 until fragments.length) {
              ipNum = fragments(i).toLong | ipNum << 8L
            }
            ipNum
          }

          /**
           * Binary search over sorted (startIp, endIp, province, city) rules.
           * Note the explicit return: it exits the loop as soon as a matching range is found.
           */
          def binarySearch(lines: ArrayBuffer[(Long, Long, String, String)], ip: Long): Int = {
            var low = 0                 // lower bound
            var high = lines.length - 1 // upper bound
            while (low <= high) {
              val middle = (low + high) / 2
              if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
                return middle
              if (ip < lines(middle)._1) {
                high = middle - 1
              } else {
                low = middle + 1
              }
            }
            -1 // not found
          }

          def binarySearch(lines: Array[(Long, Long, String, String)], ip: Long): Int = {
            var low = 0                 // lower bound
            var high = lines.length - 1 // upper bound
            while (low <= high) {
              val middle = (low + high) / 2
              if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
                return middle
              if (ip < lines(middle)._1) {
                high = middle - 1
              } else {
                low = middle + 1
              }
            }
            -1 // not found
          }
        }
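
    A quick way to check ip2Long is to compare its output with the numeric columns that are already present in the dictionary; this small check is an illustrative addition, not part of the original program:

        // Each octet occupies 8 bits, so "1.2.3.4" = 1*256^3 + 2*256^2 + 3*256 + 4 = 16909060
        assert(IpUtils.ip2Long("1.2.3.4") == 16909060L)
        // The first dictionary row starts at 1.4.8.0, whose numeric form 17041408 is exactly its third column
        assert(IpUtils.ip2Long("1.4.8.0") == 17041408L)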

    2.4 Converting the longitude/latitude in order data to a geographic location

    Environment. The order data, one JSON record per line (note that the first two records, as given, are not well-formed JSON; Spark's JSON reader treats such rows as corrupt records in its default PERMISSIVE mode):

        {"cid": 1, "money": 600.0, "longitude":116.397128,"latitude":39.916527,"oid":"o123", }
        "oid":"o112", "cid": 3, "money": 200.0, "longitude":118.396128,"latitude":35.916527}
        {"oid":"o124", "cid": 2, "money": 200.0, "longitude":117.397128,"latitude":38.916527}
        {"oid":"o125", "cid": 3, "money": 100.0, "longitude":118.397128,"latitude":35.916527}
        {"oid":"o127", "cid": 1, "money": 100.0, "longitude":116.395128,"latitude":39.916527}
        {"oid":"o128", "cid": 2, "money": 200.0, "longitude":117.396128,"latitude":38.916527}
        {"oid":"o129", "cid": 3, "money": 300.0, "longitude":115.398128,"latitude":35.916527}
        {"oid":"o130", "cid": 2, "money": 100.0, "longitude":116.397128,"latitude":39.916527}
        {"oid":"o131", "cid": 1, "money": 100.0, "longitude":117.394128,"latitude":38.916527}
        {"oid":"o132", "cid": 3, "money": 200.0, "longitude":118.396128,"latitude":35.916527}
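
    A minimal sketch of the conversion itself, reusing GeoFunc.geo from section 2.1; the object name UDF_Order2Location, the file path and the view name v_orders are illustrative assumptions, not from the original:

        import org.apache.spark.sql.{DataFrame, SparkSession}

        object UDF_Order2Location {
          def main(args: Array[String]): Unit = {
            val sparkSession: SparkSession = SparkSession.builder()
              .appName("UDF_Order2Location")
              .master("local")
              .getOrCreate()

            // Read the order records; malformed lines simply end up with null columns
            val orders: DataFrame = sparkSession.read.json("E:\\DOITLearning\\12.Spark\\order.json") // hypothetical path
            orders.createTempView("v_orders")

            // Reuse the reverse-geocoding function from section 2.1
            sparkSession.udf.register("geo", GeoFunc.geo)

            sparkSession.sql(
              """
                |select oid, cid, money, geo(longitude, latitude) as province_city
                |from v_orders
                |where longitude is not null and latitude is not null
                |""".stripMargin).show()

            sparkSession.close()
          }
        }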

    2.5 A custom aggregate function (for Spark 1.x/2.x)

    Environment. The data (employinfo.csv):

        name,salary,dept
        jack,200.2,develop
        tom,301.5,finance
        sunny,412,operating
        hanson,50000,ceo
        tompson,312,operating
        water,700.2,develop
        money,500.2,develop

    The goal is to compute the average salary per department.

        import org.apache.spark.sql.{DataFrame, Row, SparkSession}
        import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
        import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

        // This class still works in Spark 1.x/2.x; Spark 3.0 deprecates it in favour of the leaner Aggregator API
        class CustomAvgFunction extends UserDefinedAggregateFunction {

          // Input schema: the function computes an average salary, so the input is a single double value
          override def inputSchema: StructType = StructType(List(
            StructField("sal", DataTypes.DoubleType)
          ))

          // Buffer schema for the intermediate result: the running salary total and the row count
          override def bufferSchema: StructType = StructType(List(
            StructField("sum_sal", DataTypes.DoubleType),
            StructField("counts", DataTypes.IntegerType)
          ))

          // Result type: the average salary, again a double
          override def dataType: DataType = DataTypes.DoubleType

          // Deterministic: the same input always produces the same output
          override def deterministic: Boolean = true

          // Initial buffer values, similar in spirit to RDD combineByKey
          override def initialize(buffer: MutableAggregationBuffer): Unit = {
            // The running total starts at 0.0; it must be the Double literal 0.0 (not the Int 0) to match the buffer type.
            // If the aggregation were multiplicative, the neutral starting value would be 1 instead.
            buffer(0) = 0.0
            buffer(1) = 0 // row count
          }

          // Called once per input row: the partition-local (partial) aggregation
          override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
            val in: Double = input.getDouble(0)
            // Add this salary to the running total
            buffer(0) = buffer.getDouble(0) + in
            // Increment the count
            buffer(1) = buffer.getInt(1) + 1
          }

          // Merge the partial results of the individual partitions
          override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
            // Add up the totals
            buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
            // Add up the counts
            buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
          }

          // Produce the final result
          override def evaluate(buffer: Row): Any = {
            // Total divided by count; guard against a zero count if you want to be defensive
            buffer.getDouble(0) / buffer.getInt(1)
          }
        }

        object UDF_Custom_AVG_Test {
          def main(args: Array[String]): Unit = {
            val sparkSession: SparkSession = SparkSession.builder()
              .appName("UDF_Custom_AVG_Test")
              .master("local")
              .getOrCreate()

            // Read the CSV file; option() can also set the delimiter and other parsing details
            val dataFrame: DataFrame = sparkSession.read
              .option("header", true)
              .option("inferSchema", true)
              .csv("E:\\DOITLearning\\12.Spark\\employinfo.csv")

            // Register a temporary view so SQL can be run against the data
            dataFrame.createTempView("v_emp")

            // Register the custom aggregate function; this register overload is deprecated in Spark 3.0 but still usable
            sparkSession.udf.register("my_avg", new CustomAvgFunction)

            // Run SQL that calls the custom function
            val dataFrame1: DataFrame = sparkSession.sql(
              "select dept, my_avg(salary) as salary_avg from v_emp group by dept")
            dataFrame1.show()

            sparkSession.close()
          }
        }
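
    Because average is also a built-in aggregate, the custom UDAF can be sanity-checked by running the equivalent built-in aggregation against the same view and comparing the two result sets; this comparison snippet is an illustrative addition, assuming the sparkSession and v_emp from the example above:

        // Built-in average for comparison: the numbers should match my_avg,
        // e.g. develop = (200.2 + 700.2 + 500.2) / 3 ≈ 466.87
        sparkSession.sql("select dept, avg(salary) as salary_avg from v_emp group by dept").show()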

    2.6 A custom aggregate function (for Spark 3.0)

    Same data and requirement as 2.5: compute the average salary per department.

        import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession}
        import org.apache.spark.sql.expressions.Aggregator

        object UDF_Custom_AVG_Test2 {
          def main(args: Array[String]): Unit = {
            val sparkSession: SparkSession = SparkSession.builder()
              .appName("UDF_Custom_AVG_Test2")
              .master("local")
              .getOrCreate()

            // Read the CSV file; option() can also set the delimiter and other parsing details
            val dataFrame: DataFrame = sparkSession.read
              .option("header", true)
              .option("inferSchema", true)
              .csv("E:\\DOITLearning\\12.Spark\\employinfo.csv")

            // Register a temporary view so SQL can be run against the data
            dataFrame.createTempView("v_emp")

            import org.apache.spark.sql.functions._

            val myAVGFunct = new Aggregator[Double, (Double, Int), Double] {
              // Initial (zero) value of the buffer
              override def zero: (Double, Int) = (0.0, 0)

              // Partition-local aggregation: fold one input value into the buffer
              override def reduce(b: (Double, Int), a: Double): (Double, Int) = {
                (b._1 + a, b._2 + 1)
              }

              // Merge the buffers of different partitions
              override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) = {
                (b1._1 + b2._1, b1._2 + b2._2)
              }

              // Produce the final result from the buffer
              override def finish(reduction: (Double, Int)): Double = {
                reduction._1 / reduction._2
              }

              // How the intermediate buffer is encoded
              override def bufferEncoder: Encoder[(Double, Int)] = {
                Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
              }

              // How the output value is encoded
              override def outputEncoder: Encoder[Double] = {
                Encoders.scalaDouble
              }
            }

            // Register the custom aggregate function;
            // with the new API the Aggregator has to be wrapped with functions.udaf first
            sparkSession.udf.register("my_avg", udaf(myAVGFunct))

            val dataFrame1: DataFrame = sparkSession.sql(
              "select dept, my_avg(salary) as salary_avg from v_emp group by dept")
            dataFrame1.show()

            sparkSession.close()
          }
        }
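
    The same Aggregator idea also works without SQL, on a typed Dataset, which is the other usage pattern the Spark 3 API supports. A minimal sketch under the assumption that the CSV has the name/salary/dept columns shown in 2.5; the case class Emp and the object names are illustrative:

        import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
        import org.apache.spark.sql.expressions.Aggregator

        case class Emp(name: String, salary: Double, dept: String)

        object TypedAvgAggregator extends Aggregator[Emp, (Double, Int), Double] {
          override def zero: (Double, Int) = (0.0, 0)
          override def reduce(b: (Double, Int), a: Emp): (Double, Int) = (b._1 + a.salary, b._2 + 1)
          override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) = (b1._1 + b2._1, b1._2 + b2._2)
          override def finish(r: (Double, Int)): Double = r._1 / r._2
          override def bufferEncoder: Encoder[(Double, Int)] = Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
          override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
        }

        object TypedAvgTest {
          def main(args: Array[String]): Unit = {
            val spark = SparkSession.builder().appName("TypedAvgTest").master("local").getOrCreate()
            import spark.implicits._

            val employees: Dataset[Emp] = spark.read
              .option("header", true)
              .option("inferSchema", true)
              .csv("E:\\DOITLearning\\12.Spark\\employinfo.csv")
              .as[Emp]

            // groupByKey plus a typed column produced from the Aggregator
            employees.groupByKey(_.dept)
              .agg(TypedAvgAggregator.toColumn.name("salary_avg"))
              .show()

            spark.stop()
          }
        }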

    2.7 Computing a geometric mean

    Compute the geometric mean of the numbers 1 to 9 generated by range:

        import java.lang

        import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, SparkSession}
        import org.apache.spark.sql.expressions.Aggregator

        object UDF_Custom_AVG_Test3 {
          def main(args: Array[String]): Unit = {
            val sparkSession: SparkSession = SparkSession.builder()
              .appName("UDF_Custom_AVG_Test3")
              .master("local")
              .getOrCreate()

            val nums: Dataset[lang.Long] = sparkSession.range(1, 10)
            nums.createTempView("v_nums")

            import org.apache.spark.sql.functions._

            // Custom aggregate function for the geometric mean
            val agg = new Aggregator[Long, (Long, Int), Double]() {
              // For a geometric mean the neutral element of the product is 1; the count still starts at 0
              override def zero: (Long, Int) = (1L, 0)

              // Partition-local aggregation: multiply the value into the product and count it
              override def reduce(b: (Long, Int), a: Long): (Long, Int) = {
                (b._1 * a, b._2 + 1)
              }

              // Merge the buffers of different partitions
              override def merge(b1: (Long, Int), b2: (Long, Int)): (Long, Int) = {
                (b1._1 * b2._1, b1._2 + b2._2)
              }

              // Final result: the n-th root of the product
              override def finish(reduction: (Long, Int)): Double = {
                Math.pow(reduction._1.toDouble, 1 / reduction._2.toDouble)
              }

              // How the intermediate buffer is encoded
              override def bufferEncoder: Encoder[(Long, Int)] = {
                Encoders.tuple(Encoders.scalaLong, Encoders.scalaInt)
              }

              // How the output value is encoded
              override def outputEncoder: Encoder[Double] = {
                Encoders.scalaDouble
              }
            }

            // Register the aggregate function
            sparkSession.udf.register("geo_mean", udaf(agg))

            val dataFrame: DataFrame = sparkSession.sql("select geo_mean(id) from v_nums")
            dataFrame.show()

            // Print the logical and physical plans, including how the query was optimized
            dataFrame.explain(true)

            sparkSession.close()
          }
        }
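
    As a quick check of the expected value: the geometric mean of 1..9 is (9!)^(1/9) = 362880^(1/9) ≈ 4.147, which a local, Spark-free computation confirms:

        // Local sanity check of the value geo_mean should produce for range(1, 10)
        val product = (1L to 9L).product             // 362880
        println(math.pow(product.toDouble, 1.0 / 9)) // ≈ 4.147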

    2.8 Summary

    - A custom function is just like any other method you write in ordinary code: shape it to the business requirements.
    - If a function needs to be reused, pull it out into a shared/common file so it is easy to reuse.
    - A custom function must be registered before it can be used.
    - To process a DataFrame with SQL, first register it as a view; the view can be temporary or global.
    - The concepts of UDF, UDTF and UDAF are the same as in Hive: all of them can be user-defined and then called from SQL.