1. Basic Operations
package com.journey.sql;
import com.alibaba.fastjson.JSON;
import com.journey.sql.bean.User;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.apache.spark.sql.functions.col;
public class SparkSQLTest {
public static void main(String[] args) throws Exception {
SparkSession spark = SparkSession
.builder()
.appName("Demo")
.master("local[*]")
.getOrCreate();
// Read a JSON file and create a DataFrame, e.g. {"username": "lisi","age": 18}; a DataFrame is a special kind of Dataset whose rows are Row objects
Dataset<Row> df = spark.read().json("datas/sql/user.json");
// Display the table structure plus the data
df.show();
// Print the schema
df.printSchema();
// Select a column directly
df.select("username").show();
// Add 1 to age
df.select(col("username"), col("age").plus(1)).show();
// Filter rows where age is greater than 19
df.filter(col("age").gt(19)).show();
// Count the number of rows for each age
df.groupBy("age").count().show();
df.createOrReplaceTempView("user");
// Query using SQL
Dataset<Row> sqlDF = spark.sql("select * from user");
sqlDF.show();
// Register the DataFrame as a global temporary view
df.createGlobalTempView("user2");
spark.sql("select * from global_temp.user2").show();
spark.newSession().sql("select * from global_temp.user2").show();
/**
 * Datasets are similar to RDDs; however, instead of Java serialization or Kryo they use specialized
 * encoders to serialize objects for processing or for transmission over the network. While both encoders
 * and standard serialization are responsible for turning objects into bytes, encoders are dynamically
 * generated code and use a format that allows Spark to perform many operations, such as filtering,
 * sorting and hashing, without deserializing the bytes back into objects.
 */
// Note: the User class must not be declared with the static modifier
User user = new User("qiaozhanwei", 20);
Encoder<User> userEncoder = Encoders.bean(User.class);
Dataset<User> javaBeanDS = spark.createDataset(Collections.singletonList(user), userEncoder);
javaBeanDS.show();
Encoder<Integer> integerEncoder = Encoders.INT();
Dataset<Integer> primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder);
Dataset<Integer> transformedDS = primitiveDS.map((MapFunction<Integer, Integer>) value -> value + 1,
integerEncoder);
// java: incompatible types: java.lang.Object cannot be converted to java.lang.Integer[]; this line does not compile
// Integer[] collect = transformedDS.collect();
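// collectAsList() works here instead, since it returns a java.util.List<Integer>:
// List<Integer> collected = transformedDS.collectAsList();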
transformedDS.show();
Dataset<User> userDS = spark.read().json("datas/sql/user.json").as(userEncoder);
userDS.show();
JavaRDD<User> userRDD = spark.read().textFile("datas/sql/user.json")
.javaRDD()
.map(line -> {
User userInfo = JSON.parseObject(line, User.class);
return userInfo;
});
Dataset<Row> user3DF = spark.createDataFrame(userRDD, User.class);
user3DF.createOrReplaceTempView("user3");
List<User> userList = new ArrayList<>();
userList.add(new User("haha", 30));
Dataset<Row> dataFrame = spark.createDataFrame(userList, User.class);
dataFrame.show();
Dataset<Row> teenagerDF = spark.sql("select * from user3 where age between 13 and 20");
Encoder<String> stringEncoder = Encoders.STRING();
Dataset<String> teenagerNamesByIndexDF = teenagerDF.map(new MapFunction<Row, String>() {
@Override
public String call(Row value) throws Exception {
return "Name :" + value.getString(1);
}
}, stringEncoder);
teenagerNamesByIndexDF.show();
Dataset<String> teenagerNamesByFieldDF = teenagerDF.map((MapFunction<Row, String>) row -> "Name:" + row.<String>getAs("userName"),
stringEncoder);
teenagerNamesByFieldDF.show();
// Define the username field type
StructField userNameField = DataTypes.createStructField("name", DataTypes.StringType, true);
// Define the age field type
StructField ageField = DataTypes.createStructField("age", DataTypes.IntegerType, true);
List<StructField> fields = new ArrayList<>();
fields.add(userNameField);
fields.add(ageField);
StructType schema = DataTypes.createStructType(fields);
JavaRDD<String> user2RDD = spark.sparkContext().textFile("datas/sql/user.txt", 2).toJavaRDD();
JavaRDD<Row> rowRDD = user2RDD.map(new Function<String, Row>() {
@Override
public Row call(String value) throws Exception {
String[] fields = value.split(",");
return RowFactory.create(fields[0], Integer.parseInt(fields[1]));
}
});
Dataset<Row> user4DF = spark.createDataFrame(rowRDD, schema);
user4DF.createOrReplaceTempView("user4");
spark.sql("select * from user4").show();
spark.stop();
}
}
The relationship between RDD, DataFrame and Dataset, and the conversions between them
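As a quick summary of the conversions walked through above, here is a minimal sketch, assuming the same User bean, the user.json file and a local SparkSession as in the example (the class name ConversionSketch is only illustrative):

package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ConversionSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("ConversionSketch")
                .master("local[*]")
                .getOrCreate();
        // DataFrame (Dataset<Row>) -> typed Dataset<User> via an encoder
        Dataset<Row> df = spark.read().json("datas/sql/user.json");
        Dataset<User> ds = df.as(Encoders.bean(User.class));
        // Dataset -> RDD
        JavaRDD<User> userRDD = ds.javaRDD();
        // RDD -> DataFrame, schema inferred from the bean by reflection
        Dataset<Row> dfAgain = spark.createDataFrame(userRDD, User.class);
        // RDD -> Dataset, with an explicit encoder
        Dataset<User> dsAgain = spark.createDataset(userRDD.rdd(), Encoders.bean(User.class));
        dfAgain.show();
        dsAgain.show();
        spark.stop();
    }
}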
2. UDF Functions
Scalar functions
package com.journey.sql;
import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import java.util.ArrayList;
import java.util.List;
public class ScalarFunctionTest {
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("ScalarFunctionTest")
.master("local[*]")
.getOrCreate();
// Depending on the number of arguments there are also UDF2, UDF3, ...
UDF1<String, String> myUdf = new UDF1<String, String>() {
@Override
public String call(String value) throws Exception {
return "baidu-" + value;
}
};
// Register the function
spark.udf().register("myUdf", myUdf, DataTypes.StringType);
List<User> userList = new ArrayList<>();
userList.add(new User("zhangsan", 20));
Dataset<Row> df = spark.createDataFrame(userList, User.class);
df.createOrReplaceTempView("user");
spark.sql("select myUdf(userName) from user").show();
spark.stop();
}
}
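The scalar UDF above is invoked through SQL after registration; the same UDF1 can also be wrapped as a UserDefinedFunction and applied directly in the DataFrame DSL. A minimal sketch under the same setup (the class name ScalarFunctionDslSketch is only illustrative):

package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;

import java.util.ArrayList;
import java.util.List;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udf;

public class ScalarFunctionDslSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("ScalarFunctionDslSketch")
                .master("local[*]")
                .getOrCreate();
        // The same UDF1 as above, wrapped as a column-level UserDefinedFunction
        UDF1<String, String> myUdf = value -> "baidu-" + value;
        UserDefinedFunction prefixUdf = udf(myUdf, DataTypes.StringType);
        List<User> userList = new ArrayList<>();
        userList.add(new User("zhangsan", 20));
        Dataset<Row> df = spark.createDataFrame(userList, User.class);
        // Apply the UDF directly on the column, no SQL registration needed
        df.select(prefixUdf.apply(col("userName"))).show();
        spark.stop();
    }
}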
Aggregate functions
Weakly typed
package com.journey.sql;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
// Untyped (weakly typed) user-defined aggregate function
public class MyAverage1 extends UserDefinedAggregateFunction {
private StructType inputSchema;
private StructType bufferSchema;
public MyAverage1() {
List<StructField> inputFields = new ArrayList<>();
inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
inputSchema = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
bufferSchema = DataTypes.createStructType(bufferFields);
}
@Override
public StructType inputSchema() {return inputSchema;}
@Override
public StructType bufferSchema() {return bufferSchema;}
@Override
public DataType dataType() {return DataTypes.DoubleType;}
// Whether this function always returns the same output for the same input
@Override
public boolean deterministic() {return true;}
// Initialize the aggregation buffer
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, 0L);
buffer.update(1, 0L);
}
// Update the intermediate state with one input row
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
if (!input.isNullAt(0)) {
long updatedSum = buffer.getLong(0) + input.getLong(0);
long updatedCount = buffer.getLong(1) + 1;
buffer.update(0, updatedSum);
buffer.update(1, updatedCount);
}
}
// Merge two aggregation buffers
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
buffer1.update(0, mergedSum);
buffer1.update(1, mergedCount);
}
// Compute the final result
@Override
public Double evaluate(Row buffer) {
return (double) buffer.getLong(0) / buffer.getLong(1);
}
}
Strongly typed
package com.journey.sql;
import java.io.Serializable;
public class Average implements Serializable {
private long sum;
private long count;
public Average() {}
public Average(long sum, long count) {
this.sum = sum;
this.count = count;
}
public long getSum() {return sum;}
public void setSum(long sum) {this.sum = sum;}
public long getCount() {return count;}
public void setCount(long count) {this.count = count;}
}
package com.journey.sql;
import com.journey.sql.bean.User;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
public class MyAverage2 extends Aggregator<User, Average, Double> {
@Override
public Average zero() {
return new Average(0L, 0L);
}
@Override
public Average reduce(Average buffer, User user) {
long newSum = buffer.getSum() + user.getAge();
long newCount = buffer.getCount() + 1;
buffer.setSum(newSum);
buffer.setCount(newCount);
return buffer;
}
@Override
public Average merge(Average b1, Average b2) {
long mergedSum = b1.getSum() + b2.getSum();
long mergedCount = b1.getCount() + b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);
return b1;
}
@Override
public Double finish(Average reduction) {
return (double) reduction.getSum() / reduction.getCount();
}
@Override
public Encoder<Average> bufferEncoder() {
return Encoders.bean(Average.class);
}
@Override
public Encoder<Double> outputEncoder() {
return Encoders.DOUBLE();
}
}
package com.journey.sql;
import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import java.util.ArrayList;
import java.util.List;
public class AggregationsTest {
public static void main(String[] args) throws Exception {
SparkSession spark = SparkSession
.builder()
.appName("AggregationsTest")
.master("local[*]")
.getOrCreate();
spark.udf().register("myAverage1", new MyAverage1());
// Cannot be registered this way; it has to be invoked through the DSL (see the note after this example)
// spark.udf().register("MyAverge2", new MyAverage2());
List<User> userList = new ArrayList<>();
userList.add(new User("qiaozhanwei", 34));
userList.add(new User("zhangsan", 34));
Dataset<Row> df = spark.createDataFrame(userList, User.class);
df.createOrReplaceTempView("user");
spark.sql("select myAverage1(age) from user").show();
MyAverage2 myAverage = new MyAverage2();
TypedColumn<User, Double> averageAge = myAverage.toColumn().name("average_age");
Encoder<User> userEncoder = Encoders.bean(User.class);
Dataset<Double> average = df.as(userEncoder).select(averageAge);
average.show();
spark.stop();
}
}
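Two notes on the registration question flagged above: the untyped UserDefinedAggregateFunction API used by MyAverage1 is deprecated since Spark 3.0, and from Spark 3.0 onward a typed Aggregator can be registered for SQL through functions.udaf, provided its input type is the column type being aggregated. A minimal sketch under that assumption (MyLongAverage is a hypothetical variant of MyAverage2 that aggregates a long column rather than a whole User):

package com.journey.sql;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

// Hypothetical variant of MyAverage2: the input is the age column itself, not the User bean
public class MyLongAverage extends Aggregator<Long, Average, Double> {
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    @Override
    public Average reduce(Average buffer, Long value) {
        buffer.setSum(buffer.getSum() + value);
        buffer.setCount(buffer.getCount() + 1);
        return buffer;
    }

    @Override
    public Average merge(Average b1, Average b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    @Override
    public Double finish(Average reduction) {
        return (double) reduction.getSum() / reduction.getCount();
    }

    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

With that variant, the registration commented out in AggregationsTest becomes possible on Spark 3.0+:

// import static org.apache.spark.sql.functions.udaf;
spark.udf().register("myAverage2", udaf(new MyLongAverage(), Encoders.LONG()));
spark.sql("select myAverage2(age) from user").show();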