Spark SQL Java Basics

1. Basic operations

package com.journey.sql;

import com.alibaba.fastjson.JSON;
import com.journey.sql.bean.User;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.apache.spark.sql.functions.col;

public class SparkSQLTest {

    public static void main(String[] args) throws Exception {
        SparkSession spark = SparkSession
                .builder()
                .appName("Demo")
                .master("local[*]")
                .getOrCreate();
        // Read a JSON file to create a DataFrame, e.g. {"username": "lisi", "age": 18}; a DataFrame is a special Dataset whose rows are Row objects
        Dataset<Row> df = spark.read().json("datas/sql/user.json");
        // Display the table structure plus the data
        df.show();
        // Print the schema
        df.printSchema();
        // Select a column directly
        df.select("username").show();
        // Select username and age + 1
        df.select(col("username"), col("age").plus(1)).show();
        // Filter rows where age is greater than 19
        df.filter(col("age").gt(19)).show();
        // Count rows per age value
        df.groupBy("age").count().show();

        df.createOrReplaceTempView("user");

        // Query with SQL
        Dataset<Row> sqlDF = spark.sql("select * from user");
        sqlDF.show();

        // Register the DataFrame as a global temporary view
        df.createGlobalTempView("user2");

        spark.sql("select * from global_temp.user2").show();
        spark.newSession().sql("select * from global_temp.user2").show();


        /**
         * Datasets are similar to RDDs; however, instead of Java serialization or Kryo, they use a
         * specialized Encoder to serialize objects for processing or for transmission over the network.
         * While both encoders and standard serialization turn objects into bytes, encoders are code
         * generated dynamically, and they use a format that allows Spark to perform many operations,
         * such as filtering, sorting and hashing, without deserializing the bytes back into an object.
         */
        // Note: the User class must not be declared static
        User user = new User("qiaozhanwei", 20);
        Encoder<User> userEncoder = Encoders.bean(User.class);
        Dataset<User> javaBeanDS = spark.createDataset(Collections.singletonList(user), userEncoder);
        javaBeanDS.show();

        Encoder<Integer> integerEncoder = Encoders.INT();
        Dataset<Integer> primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder);
        Dataset<Integer> transformedDS = primitiveDS.map((MapFunction<Integer, Integer>) value -> value + 1,
                integerEncoder);
        // javac error: incompatible types: java.lang.Object cannot be converted to java.lang.Integer[], so this does not compile
        // Integer[] collect = transformedDS.collect();
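        // In the Java API, collectAsList() returns a List<Integer> and sidesteps the Object return type that collect() exposes to Java:
        // List<Integer> results = transformedDS.collectAsList();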
        transformedDS.show();

        Dataset<User> userDS = spark.read().json("datas/sql/user.json").as(userEncoder);
        userDS.show();

        JavaRDD<User> userRDD = spark.read().textFile("datas/sql/user.json")
                .javaRDD()
                .map(line -> {
                    User userInfo = JSON.parseObject(line, User.class);
                    return userInfo;
                });

        Dataset<Row> user3DF = spark.createDataFrame(userRDD, User.class);
        user3DF.createOrReplaceTempView("user3");

        List<User> userList = new ArrayList<>();
        userList.add(new User("haha", 30));

        Dataset<Row> dataFrame = spark.createDataFrame(userList, User.class);
        dataFrame.show();

        Dataset<Row> teenagerDF = spark.sql("select * from user3 where age between 13 and 20");

        Encoder<String> stringEncoder = Encoders.STRING();
        Dataset<String> teenagerNamesByIndexDF = teenagerDF.map(new MapFunction<Row, String>() {
            @Override
            public String call(Row value) throws Exception {
                return "Name :" + value.getString(1);
            }
        }, stringEncoder);
        teenagerNamesByIndexDF.show();

        Dataset<String> teenagerNamesByFieldDF = teenagerDF.map((MapFunction<Row, String>) row -> "Name:" + row.<String>getAs("userName"),
                stringEncoder);
        teenagerNamesByFieldDF.show();

        // Define the name field
        StructField userNameField = DataTypes.createStructField("name", DataTypes.StringType, true);
        // Define the age field
        StructField ageField = DataTypes.createStructField("age", DataTypes.IntegerType, true);
        List<StructField> fields = new ArrayList<>();
        fields.add(userNameField);
        fields.add(ageField);

        StructType schema = DataTypes.createStructType(fields);

        JavaRDD<String> user2RDD = spark.sparkContext().textFile("datas/sql/user.txt", 2).toJavaRDD();
        JavaRDD<Row> rowRDD = user2RDD.map(new Function<String, Row>() {
            @Override
            public Row call(String value) throws Exception {
                String[] fields = value.split(",");
                return RowFactory.create(fields[0], Integer.parseInt(fields[1]));
            }
        });

        Dataset<Row> user4DF = spark.createDataFrame(rowRDD, schema);
        user4DF.createOrReplaceTempView("user4");

        spark.sql("select * from user4").show();


        spark.stop();
    }
}

The relationship between RDD, DataFrame and Dataset, and how to convert between them
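
To make these conversions concrete, here is a minimal, self-contained sketch (the class name ConversionSketch is made up for illustration; it reuses the User bean and user.json from the example above) that moves the same data between a DataFrame, a typed Dataset and a JavaRDD:

package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ConversionSketch {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("ConversionSketch")
                .master("local[*]")
                .getOrCreate();

        Encoder<User> userEncoder = Encoders.bean(User.class);

        // A DataFrame is simply a Dataset<Row>
        Dataset<Row> df = spark.read().json("datas/sql/user.json");

        // DataFrame -> typed Dataset<User>
        Dataset<User> ds = df.as(userEncoder);

        // Dataset<User> -> DataFrame: drop the type information again
        Dataset<Row> backToDf = ds.toDF();

        // Dataset / DataFrame -> JavaRDD
        JavaRDD<User> userRdd = ds.javaRDD();
        JavaRDD<Row> rowRdd = backToDf.javaRDD();

        // JavaRDD -> DataFrame: attach a schema, here inferred from the User bean
        Dataset<Row> fromRdd = spark.createDataFrame(userRdd, User.class);
        fromRdd.show();

        spark.stop();
    }
}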

2. UDF functions

Scalar functions

package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

import java.util.ArrayList;
import java.util.List;

public class ScalarFunctionTest {

    public static void main(String[] args) {

        SparkSession spark = SparkSession
                .builder()
                .appName("ScalarFunctionTest")
                .master("local[*]")
                .getOrCreate();

        // With more arguments there are UDF2, UDF3, ... (a two-argument sketch follows this example)
        UDF1<String, String> myUdf = new UDF1<String, String>() {
            @Override
            public String call(String value) throws Exception {
                return "baidu-" + value;
            }
        };

        // Register the function
        spark.udf().register("myUdf", myUdf, DataTypes.StringType);

        List<User> userList = new ArrayList<>();
        userList.add(new User("zhangsan", 20));

        Dataset<Row> df = spark.createDataFrame(userList, User.class);


        df.createOrReplaceTempView("user");

        spark.sql("select myUdf(userName) from user").show();

        spark.stop();
    }
}
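
As the comment above notes, Spark also ships UDF2 up to UDF22 in org.apache.spark.sql.api.java for functions with more arguments. Below is a small sketch that would slot into ScalarFunctionTest.main right after the existing registration; the function nameWithAge and its logic are made up for illustration:

        // requires: import org.apache.spark.sql.api.java.UDF2;
        // A two-argument UDF that concatenates a user name with an age
        UDF2<String, Integer, String> nameWithAge = new UDF2<String, Integer, String>() {
            @Override
            public String call(String name, Integer age) throws Exception {
                return name + "-" + age;
            }
        };

        // Registration looks exactly like UDF1, just with the matching interface and return type
        spark.udf().register("nameWithAge", nameWithAge, DataTypes.StringType);

        spark.sql("select nameWithAge(userName, age) from user").show();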

Aggregate functions

Untyped (weakly typed)

package com.journey.sql;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

// Untyped user-defined aggregate function
public class MyAverage1 extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public MyAverage1() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);

    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    // Whether this function always returns the same output for the same input
    @Override
    public boolean deterministic() {
        return true;
    }

    // Initialize the aggregation buffer
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L);
        buffer.update(1, 0L);
    }

    // Update the buffer with a new input row
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            long updatedSum = buffer.getLong(0) + input.getLong(0);
            long updatedCount = buffer.getLong(1) + 1;
            buffer.update(0, updatedSum);
            buffer.update(1, updatedCount);
        }
    }

    // Merge two aggregation buffers
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
        long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
        buffer1.update(0, mergedSum);
        buffer1.update(1, mergedCount);
    }

    // Compute the final result
    @Override
    public Double evaluate(Row buffer) {
        return (double) buffer.getLong(0) / buffer.getLong(1);
    }
}

Type-safe (strongly typed)

package com.journey.sql;

import java.io.Serializable;

public class Average implements Serializable {

    private long sum;
    private long count;

    public Average() {}

    public Average(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }

    public long getSum() {
        return sum;
    }

    public void setSum(long sum) {
        this.sum = sum;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }
}
package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;


public class MyAverage2 extends Aggregator<User, Average, Double> {

    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    @Override
    public Average reduce(Average buffer, User user) {
        long newSum = buffer.getSum() + user.getAge();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    @Override
    public Double finish(Average reduction) {
        return (double) reduction.getSum() / reduction.getCount();
    }

    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;

import java.util.ArrayList;
import java.util.List;

public class AggregationsTest {

    public static void main(String[] args) throws Exception {
        SparkSession spark = SparkSession
                .builder()
                .appName("AggregationsTest")
                .master("local[*]")
                .getOrCreate();

        spark.udf().register("myAverage1", new MyAverage1());
        // A typed Aggregator cannot be registered this way; it is invoked through the typed DSL instead (see the note after this example)
//        spark.udf().register("MyAverge2", new MyAverage2());

        List<User> userList = new ArrayList<>();
        userList.add(new User("qiaozhanwei", 34));
        userList.add(new User("zhangsan", 34));

        Dataset<Row> df = spark.createDataFrame(userList, User.class);
        df.createOrReplaceTempView("user");

        spark.sql("select myAverage1(age) from user").show();

        MyAverage2 myAverage = new MyAverage2();
        TypedColumn<User, Double> averageAge = myAverage.toColumn().name("average_age");
        Encoder<User> userEncoder = Encoders.bean(User.class);
        Dataset<Double> average = df.as(userEncoder).select(averageAge);
        average.show();

        spark.stop();
    }
}
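
A note on the commented-out registration above: spark.udf().register has no overload that accepts an Aggregator directly, which is why MyAverage2 is only usable through the typed DSL in this example. Since Spark 3.0, functions.udaf can wrap an Aggregator so that it becomes callable from SQL as well. The sketch below assumes Spark 3.0+ and uses a hypothetical LongAverage aggregator that averages a plain numeric column instead of a whole User bean; it reuses the Average buffer class from above:

package com.journey.sql;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

// Hypothetical typed aggregator whose input is the raw column value, not the whole row
public class LongAverage extends Aggregator<Long, Average, Double> {

    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    @Override
    public Average reduce(Average buffer, Long value) {
        buffer.setSum(buffer.getSum() + value);
        buffer.setCount(buffer.getCount() + 1);
        return buffer;
    }

    @Override
    public Average merge(Average b1, Average b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    @Override
    public Double finish(Average reduction) {
        return (double) reduction.getSum() / reduction.getCount();
    }

    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

With that in place, registration and use from SQL would look roughly like this (inside AggregationsTest.main, after the user view has been created; the cast keeps the column type aligned with the aggregator's Long input):

        // requires: import static org.apache.spark.sql.functions.udaf;
        spark.udf().register("myAverage2", udaf(new LongAverage(), Encoders.LONG()));
        spark.sql("select myAverage2(cast(age as bigint)) from user").show();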
