1. Basic Operations
package com.journey.sql;

import com.alibaba.fastjson.JSON;
import com.journey.sql.bean.User;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.apache.spark.sql.functions.col;

public class SparkSQLTest {
    public static void main(String[] args) throws Exception {
        SparkSession spark = SparkSession
                .builder()
                .appName("Demo")
                .master("local[*]")
                .getOrCreate();

        // Read a JSON file to create a DataFrame; each line looks like {"username": "lisi","age": 18}.
        // A DataFrame is a special kind of Dataset whose rows are Row objects.
        Dataset<Row> df = spark.read().json("datas/sql/user.json");

        // Show the table structure plus the data
        df.show();

        // Print the schema
        df.printSchema();

        // Select directly
        df.select("username").show();

        // age plus 1
        df.select(col("username"), col("age").plus(1)).show();

        // Filter rows where age > 19
        df.filter(col("age").gt(19)).show();

        // Count rows grouped by age
        df.groupBy("age").count().show();

        df.createOrReplaceTempView("user");

        // Query using SQL
        Dataset<Row> sqlDF = spark.sql("select * from user");
        sqlDF.show();

        // Register the DataFrame as a global temporary view
        df.createGlobalTempView("user2");

        spark.sql("select * from global_temp.user2").show();
        spark.newSession().sql("select * from global_temp.user2").show();

        /**
         * Datasets are similar to RDDs, but instead of Java serialization or Kryo they use
         * specialized encoders to serialize objects for processing or network transfer.
         * While both encoders and standard serialization turn objects into bytes, encoders are
         * dynamically generated code and use a format that lets Spark perform many operations,
         * such as filtering, sorting and hashing, without deserializing the bytes back into objects.
         */
        // Note: User must not be declared with the static modifier
        User user = new User("qiaozhanwei", 20);
        Encoder<User> userEncoder = Encoders.bean(User.class);
        Dataset<User> javaBeanDS = spark.createDataset(Collections.singletonList(user), userEncoder);
        javaBeanDS.show();

        Encoder<Integer> integerEncoder = Encoders.INT();
        Dataset<Integer> primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder);
        Dataset<Integer> transformedDS = primitiveDS.map(
                (MapFunction<Integer, Integer>) value -> value + 1, integerEncoder);
        // javac error: incompatible types: java.lang.Object cannot be converted to java.lang.Integer[]; this line does not compile
        // Integer[] collect = transformedDS.collect();
        transformedDS.show();

        Dataset<User> userDS = spark.read().json("datas/sql/user.json").as(userEncoder);
        userDS.show();

        JavaRDD<User> userRDD = spark.read().textFile("datas/sql/user.json")
                .javaRDD()
                .map(line -> {
                    User userInfo = JSON.parseObject(line, User.class);
                    return userInfo;
                });
        Dataset<Row> user3DF = spark.createDataFrame(userRDD, User.class);
        user3DF.createOrReplaceTempView("user3");

        List<User> userList = new ArrayList<>();
        userList.add(new User("haha", 30));
        Dataset<Row> dataFrame = spark.createDataFrame(userList, User.class);
        dataFrame.show();

        Dataset<Row> teenagerDF = spark.sql("select * from user3 where age between 13 and 20");
        Encoder<String> stringEncoder = Encoders.STRING();
        Dataset<String> teenagerNamesByIndexDF = teenagerDF.map(new MapFunction<Row, String>() {
            @Override
            public String call(Row value) throws Exception {
                return "Name : " + value.getString(1);
            }
        }, stringEncoder);
        teenagerNamesByIndexDF.show();

        Dataset<String> teenagerNamesByFieldDF = teenagerDF.map(
                (MapFunction<Row, String>) row -> "Name: " + row.<String>getAs("userName"),
                stringEncoder);
        teenagerNamesByFieldDF.show();

        // Define the username field type
        StructField userNameField =
                DataTypes.createStructField("name", DataTypes.StringType, true);
        // Define the age field type
        StructField ageField = DataTypes.createStructField("age", DataTypes.IntegerType, true);
        List<StructField> fields = new ArrayList<>();
        fields.add(userNameField);
        fields.add(ageField);
        StructType schema = DataTypes.createStructType(fields);

        JavaRDD<String> user2RDD = spark.sparkContext().textFile("datas/sql/user.txt", 2).toJavaRDD();
        JavaRDD<Row> rowRDD = user2RDD.map(new Function<String, Row>() {
            @Override
            public Row call(String value) throws Exception {
                String[] fields = value.split(",");
                return RowFactory.create(fields[0], Integer.parseInt(fields[1]));
            }
        });
        Dataset<Row> user4DF = spark.createDataFrame(rowRDD, schema);
        user4DF.createOrReplaceTempView("user4");
        spark.sql("select * from user4").show();

        spark.stop();
    }
}
The relationship and conversions between RDD, DataFrame, and Dataset
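To make those conversions concrete, here is a minimal sketch; the class name ConversionTest is made up for illustration, and it assumes the same User bean and datas/sql/user.json path used in the example above. It moves one set of data through all three representations.

package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ConversionTest {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("ConversionTest")
                .master("local[*]")
                .getOrCreate();

        Encoder<User> userEncoder = Encoders.bean(User.class);

        // DataFrame (Dataset<Row>) -> typed Dataset<User>, via an encoder
        Dataset<Row> df = spark.read().json("datas/sql/user.json");
        Dataset<User> ds = df.as(userEncoder);

        // Dataset<User> -> JavaRDD<User>
        JavaRDD<User> rdd = ds.toJavaRDD();

        // RDD -> DataFrame, letting Spark infer the schema from the bean class
        Dataset<Row> dfAgain = spark.createDataFrame(rdd, User.class);
        dfAgain.show();

        // Typed Dataset -> DataFrame by dropping the row type
        ds.toDF().show();

        spark.stop();
    }
}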
2. UDF Functions
Scalar functions
package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

import java.util.ArrayList;
import java.util.List;

public class ScalarFunctionTest {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("ScalarFunctionTest")
                .master("local[*]")
                .getOrCreate();

        // Depending on the number of arguments there are also UDF2, UDF3, ...
        UDF1<String, String> myUdf = new UDF1<String, String>() {
            @Override
            public String call(String value) throws Exception {
                return "baidu-" + value;
            }
        };

        // Register the function
        spark.udf().register("myUdf", myUdf, DataTypes.StringType);

        List<User> userList = new ArrayList<>();
        userList.add(new User("zhangsan", 20));
        Dataset<Row> df = spark.createDataFrame(userList, User.class);
        df.createOrReplaceTempView("user");

        spark.sql("select myUdf(userName) from user").show();

        spark.stop();
    }
}
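As the comment above hints, Spark also ships UDF2, UDF3, and so on for functions that take more arguments. The following is a hedged sketch of a two-argument UDF registered the same way; the class name ScalarFunction2Test and the function name nameWithAge are made up for illustration, not part of the original code.

package com.journey.sql;

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;

public class ScalarFunction2Test {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("ScalarFunction2Test")
                .master("local[*]")
                .getOrCreate();

        // A two-argument UDF: concatenates a name and an age into one string
        UDF2<String, Integer, String> nameWithAge = (name, age) -> name + "-" + age;

        // Registration is the same as for UDF1; only the interface changes
        spark.udf().register("nameWithAge", nameWithAge, DataTypes.StringType);

        spark.sql("select nameWithAge('zhangsan', 20)").show();

        spark.stop();
    }
}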
Aggregate functions
Untyped (weakly typed)
package com.journey.sql;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

// Untyped user-defined aggregate function
public class MyAverage1 extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public MyAverage1() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    // Whether this function always returns the same output for the same input
    @Override
    public boolean deterministic() {
        return true;
    }

    // Initialize the aggregation buffer
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L);
        buffer.update(1, 0L);
    }

    // Update the intermediate state with a new input row
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            long updatedSum = buffer.getLong(0) + input.getLong(0);
            long updatedCount = buffer.getLong(1) + 1;
            buffer.update(0, updatedSum);
            buffer.update(1, updatedCount);
        }
    }

    // Merge two intermediate buffers
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
        long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
        buffer1.update(0, mergedSum);
        buffer1.update(1, mergedCount);
    }

    // Compute the final result
    @Override
    public Double evaluate(Row buffer) {
        return (double) buffer.getLong(0) / buffer.getLong(1);
    }
}
Strongly typed
package com.journey.sql;

import java.io.Serializable;

public class Average implements Serializable {

    private long sum;
    private long count;

    public Average() {}

    public Average(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }

    public long getSum() {
        return sum;
    }

    public void setSum(long sum) {
        this.sum = sum;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }
}
package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

public class MyAverage2 extends Aggregator<User, Average, Double> {

    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    @Override
    public Average reduce(Average buffer, User user) {
        long newSum = buffer.getSum() + user.getAge();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    @Override
    public Double finish(Average reduction) {
        return (double) reduction.getSum() / reduction.getCount();
    }

    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;

import java.util.ArrayList;
import java.util.List;

public class AggregationsTest {
    public static void main(String[] args) throws Exception {
        SparkSession spark = SparkSession
                .builder()
                .appName("AggregationsTest")
                .master("local[*]")
                .getOrCreate();

        spark.udf().register("myAverage1", new MyAverage1());

        // Cannot be registered this way? It has to be invoked through the DSL syntax
        // (see the Spark 3.x alternative sketched after this example)
        // spark.udf().register("myAverage2", new MyAverage2());

        List<User> userList = new ArrayList<>();
        userList.add(new User("qiaozhanwei", 34));
        userList.add(new User("zhangsan", 34));
        Dataset<Row> df = spark.createDataFrame(userList, User.class);
        df.createOrReplaceTempView("user");

        // Untyped UDAF, called from SQL
        spark.sql("select myAverage1(age) from user").show();

        // Typed Aggregator, called through the DSL
        MyAverage2 myAverage = new MyAverage2();
        TypedColumn<User, Double> averageAge = myAverage.toColumn().name("average_age");
        Encoder<User> userEncoder = Encoders.bean(User.class);
        Dataset<Double> average = df.as(userEncoder).select(averageAge);
        average.show();

        spark.stop();
    }
}
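A note on the commented-out registration above: on Spark 3.0 and later, UserDefinedAggregateFunction is deprecated, and an Aggregator can be registered for use in SQL by wrapping it with functions.udaf, provided its input type is the column value rather than the whole row bean. The following is a minimal sketch under that assumption; MyAverage3 is a hypothetical column-level variant of MyAverage2 and is not part of the original code.

package com.journey.sql;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

// Hypothetical variant of MyAverage2 whose input is the age value (Long) instead of a User bean
public class MyAverage3 extends Aggregator<Long, Average, Double> {

    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    @Override
    public Average reduce(Average buffer, Long age) {
        buffer.setSum(buffer.getSum() + age);
        buffer.setCount(buffer.getCount() + 1);
        return buffer;
    }

    @Override
    public Average merge(Average b1, Average b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    @Override
    public Double finish(Average reduction) {
        return (double) reduction.getSum() / reduction.getCount();
    }

    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

Registration and use in SQL would then look like this (Spark 3.0+ only; the explicit cast avoids relying on implicit int-to-bigint coercion):

// Wrap the typed Aggregator as a UDAF and register it for SQL (Spark 3.0+)
spark.udf().register("myAverage3",
        org.apache.spark.sql.functions.udaf(new MyAverage3(), Encoders.LONG()));
spark.sql("select myAverage3(cast(age as bigint)) from user").show();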