1、基础操作

package com.journey.sql;import com.alibaba.fastjson.JSON;import com.journey.sql.bean.User;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.function.Function;import org.apache.spark.api.java.function.MapFunction;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Encoder;import org.apache.spark.sql.Encoders;import org.apache.spark.sql.Row;import org.apache.spark.sql.RowFactory;import org.apache.spark.sql.SparkSession;import org.apache.spark.sql.types.DataTypes;import org.apache.spark.sql.types.StructField;import org.apache.spark.sql.types.StructType;import java.util.ArrayList;import java.util.Arrays;import java.util.Collections;import java.util.List;import static org.apache.spark.sql.functions.col;public class SparkSQLTest {    public static void main(String[] args) throws Exception {        SparkSession spark = SparkSession                .builder()                .appName("Demo")                .master("local[*]")                .getOrCreate();        // 读取 json 文件 创立 DataFrame {"username": "lisi","age": 18},DataFrame是一种非凡的Dataset,行是Row        Dataset<Row> df = spark.read().json("datas/sql/user.json");        // 展现表构造 + 数据        df.show();        // 打印schema构造        df.printSchema();        // 间接select        df.select("username").show();        // 加1        df.select(col("username"), col("age").plus(1)).show();        // 过滤age大于19        df.filter(col("age").gt(19)).show();        // 统计age的个数        df.groupBy("age").count().show();        df.createOrReplaceTempView("user");        // 应用sql来查问        Dataset<Row> sqlDF = spark.sql("select * from user");        sqlDF.show();        // 注册DataFrame作为一个全局的长期视图        df.createGlobalTempView("user2");        spark.sql("select * from global_temp.user2").show();        spark.newSession().sql("select * from global_temp.user2").show();        /**         * 数据集与 RDD 相似,然而,它们不应用 Java 序列化或 Kryo,而是应用专门的编码器来序列化对象以进行解决或通过网络传输。         * 尽管编码器和规范序列化都负责将对象转换为字节,但编码器是动静生成的代码,并应用一种格局,容许 Spark 
执行许多操作,         * 如过滤、排序和散列,而无需将字节反序列化回对象。         */        // 留神 : User不能是static润饰        User user = new User("qiaozhanwei", 20);        Encoder<User> userEncoder = Encoders.bean(User.class);        Dataset<User> javaBeanDS = spark.createDataset(Collections.singletonList(user), userEncoder);        javaBeanDS.show();        Encoder<Integer> integerEncoder = Encoders.INT();        Dataset<Integer> primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder);        Dataset<Integer> transformedDS = primitiveDS.map(                (MapFunction<Integer, Integer>) value -> value + 1,                integerEncoder);        // java: 不兼容的类型: java.lang.Object无奈转换为java.lang.Integer[],跑不通        // Integer[] collect = transformedDS.collect();        transformedDS.show();        Dataset<User> userDS = spark.read().json("datas/sql/user.json").as(userEncoder);        userDS.show();        JavaRDD<User> userRDD = spark.read().textFile("datas/sql/user.json")                .javaRDD()                .map(line -> {                    User userInfo = JSON.parseObject(line, User.class);                    return userInfo;                });        Dataset<Row> user3DF = spark.createDataFrame(userRDD, User.class);        user3DF.createOrReplaceTempView("user3");        List<User> userList = new ArrayList<>();        userList.add(new User("haha", 30));        Dataset<Row> dataFrame = spark.createDataFrame(userList, User.class);        dataFrame.show();        Dataset<Row> teenagerDF = spark.sql("select * from user3 where age between 13 and 20");        Encoder<String> stringEncoder = Encoders.STRING();        Dataset<String> teenagerNamesByIndexDF = teenagerDF.map(new MapFunction<Row, String>() {            @Override            public String call(Row value) throws Exception {                return "Name : " + value.getString(1);            }        }, stringEncoder);        teenagerNamesByIndexDF.show();        Dataset<String> teenagerNamesByFieldDF = teenagerDF.map(     
           (MapFunction<Row, String>) row -> "Name: " + row.<String>getAs("userName"),                stringEncoder);        teenagerNamesByFieldDF.show();        // 定义用户名字段类型        StructField userNameField = DataTypes.createStructField("name", DataTypes.StringType, true);        // 定义年龄字段类型        StructField ageField = DataTypes.createStructField("age", DataTypes.IntegerType, true);        List<StructField> fields = new ArrayList<>();        fields.add(userNameField);        fields.add(ageField);        StructType schema = DataTypes.createStructType(fields);        JavaRDD<String> user2RDD = spark.sparkContext().textFile("datas/sql/user.txt", 2).toJavaRDD();        JavaRDD<Row> rowRDD = user2RDD.map(new Function<String, Row>() {            @Override            public Row call(String value) throws Exception {                String[] fields = value.split(",");                return RowFactory.create(fields[0], Integer.parseInt(fields[1]));            }        });        Dataset<Row> user4DF = spark.createDataFrame(rowRDD, schema);        user4DF.createOrReplaceTempView("user4");        spark.sql("select * from user4").show();        spark.stop();    }}

RDD、DataFrame和Dataset的关系及转换

2、UDF函数

标量函数

package com.journey.sql;import com.journey.sql.bean.User;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.SparkSession;import org.apache.spark.sql.api.java.UDF1;import org.apache.spark.sql.types.DataTypes;import java.util.ArrayList;import java.util.List;public class ScalarFunctionTest {    public static void main(String[] args) {        SparkSession spark = SparkSession                .builder()                .appName("ScalarFunctionTest")                .master("local[*]")                .getOrCreate();        // 依据参数有UDF2....        UDF1<String, String> myUdf = new UDF1<String, String>() {            @Override            public String call(String value) throws Exception {                return "baidu-" + value;            }        };        // 函数注册        spark.udf().register("myUdf", myUdf, DataTypes.StringType);        List<User> userList = new ArrayList<>();        userList.add(new User("zhangsan", 20));        Dataset<Row> df = spark.createDataFrame(userList, User.class);        df.createOrReplaceTempView("user");        spark.sql("select myUdf(userName) from user").show();        spark.stop();    }}

聚合函数

弱类型

package com.journey.sql;import org.apache.spark.sql.Row;import org.apache.spark.sql.expressions.MutableAggregationBuffer;import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;import org.apache.spark.sql.types.DataType;import org.apache.spark.sql.types.DataTypes;import org.apache.spark.sql.types.StructField;import org.apache.spark.sql.types.StructType;import java.util.ArrayList;import java.util.List;// 无类型public class MyAverage1 extends UserDefinedAggregateFunction {    private StructType inputSchema;    private StructType bufferSchema;    public MyAverage1() {        List<StructField> inputFields = new ArrayList<>();        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));        inputSchema = DataTypes.createStructType(inputFields);        List<StructField> bufferFields = new ArrayList<>();        bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));        bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));        bufferSchema = DataTypes.createStructType(bufferFields);    }    @Override    public StructType inputSchema() {        return inputSchema;    }    @Override    public StructType bufferSchema() {        return bufferSchema;    }    @Override    public DataType dataType() {        return DataTypes.DoubleType;    }    // 此函数是否总是在雷同的输出上返回雷同的输入    @Override    public boolean deterministic() {        return true;    }    // 初始化    @Override    public void initialize(MutableAggregationBuffer buffer) {        buffer.update(0, 0L);        buffer.update(1, 0L);    }    // 中间状态更新    @Override    public void update(MutableAggregationBuffer buffer, Row input) {        if (!input.isNullAt(0)) {            long updatedSum = buffer.getLong(0) + input.getLong(0);            long updatedCount = buffer.getLong(1) + 1;            buffer.update(0, updatedSum);            buffer.update(1, updatedCount);        }    }    // 合并    @Override    public void 
merge(MutableAggregationBuffer buffer1, Row buffer2) {        long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);        long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);        buffer1.update(0, mergedSum);        buffer1.update(1, mergedCount);    }    // 计算结果    @Override    public Double evaluate(Row buffer) {        return (double) buffer.getLong(0) / buffer.getLong(1);    }}

强类型

package com.journey.sql;import java.io.Serializable;public class Average implements Serializable {    private long sum;    private long count;    public Average() {}    public Average(long sum, long count) {        this.sum = sum;        this.count = count;    }    public long getSum() {        return sum;    }    public void setSum(long sum) {        this.sum = sum;    }    public long getCount() {        return count;    }    public void setCount(long count) {        this.count = count;    }}
package com.journey.sql;import com.journey.sql.bean.User;import org.apache.spark.sql.Encoder;import org.apache.spark.sql.Encoders;import org.apache.spark.sql.expressions.Aggregator;public class MyAverage2 extends Aggregator<User, Average, Double> {    @Override    public Average zero() {        return new Average(0L, 0L);    }    @Override    public Average reduce(Average buffer, User user) {        long newSum = buffer.getSum() + user.getAge();        long newCount = buffer.getCount() + 1;        buffer.setSum(newSum);        buffer.setCount(newCount);        return buffer;    }    @Override    public Average merge(Average b1, Average b2) {        long mergedSum = b1.getSum() + b2.getSum();        long mergedCount = b1.getCount() + b2.getCount();        b1.setSum(mergedSum);        b1.setCount(mergedCount);        return b1;    }    @Override    public Double finish(Average reduction) {        return (double) reduction.getSum() / reduction.getCount();    }    @Override    public Encoder<Average> bufferEncoder() {        return Encoders.bean(Average.class);    }    @Override    public Encoder<Double> outputEncoder() {        return Encoders.DOUBLE();    }}
package com.journey.sql;import com.journey.sql.bean.User;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Encoder;import org.apache.spark.sql.Encoders;import org.apache.spark.sql.Row;import org.apache.spark.sql.SparkSession;import org.apache.spark.sql.TypedColumn;import java.util.ArrayList;import java.util.List;public class AggregationsTest {    public static void main(String[] args) throws Exception {        SparkSession spark = SparkSession                .builder()                .appName("AggregationsTest")                .master("local[*]")                .getOrCreate();        spark.udf().register("myAverage1", new MyAverage1());        // 不能进行注册 ?必须应用DSL语法调用//        spark.udf().register("MyAverge2", new MyAverage2());        List<User> userList = new ArrayList<>();        userList.add(new User("qiaozhanwei", 34));        userList.add(new User("zhangsan", 34));        Dataset<Row> df = spark.createDataFrame(userList, User.class);        df.createOrReplaceTempView("user");        spark.sql("select myAverage1(age) from user").show();        MyAverage2 myAverage = new MyAverage2();        TypedColumn<User, Double> averageAge = myAverage.toColumn().name("average_age");        Encoder<User> userEncoder = Encoders.bean(User.class);        Dataset<Double> average = df.as(userEncoder).select(averageAge);        average.show();        spark.stop();    }}