Talking about Flink's ParallelIteratorInputFormat



This article takes a look at Flink's ParallelIteratorInputFormat.
Example
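import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
// both bounds are inclusive, so this produces the 92 numbers 15..106
DataSet<Long> dataSet = env.generateSequence(15, 106)
.setParallelism(3);
dataSet.print(); // print() triggers execution and declares throws Exception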
Here ExecutionEnvironment's generateSequence method creates a ParallelIteratorInputFormat backed by a NumberSequenceIterator.
ParallelIteratorInputFormat
flink-java-1.6.2-sources.jar!/org/apache/flink/api/java/io/ParallelIteratorInputFormat.java
/**
* An input format that generates data in parallel through a {@link SplittableIterator}.
*/
@PublicEvolving
public class ParallelIteratorInputFormat<T> extends GenericInputFormat<T> {

private static final long serialVersionUID = 1L;

private final SplittableIterator<T> source;

private transient Iterator<T> splitIterator;

public ParallelIteratorInputFormat(SplittableIterator<T> iterator) {
this.source = iterator;
}

@Override
public void open(GenericInputSplit split) throws IOException {
super.open(split);

this.splitIterator = this.source.getSplit(split.getSplitNumber(), split.getTotalNumberOfSplits());
}

@Override
public boolean reachedEnd() {
return !this.splitIterator.hasNext();
}

@Override
public T nextRecord(T reuse) {
return this.splitIterator.next();
}
}
ParallelIteratorInputFormat extends GenericInputFormat. GenericInputFormat has four other subclasses, CRowValuesInputFormat, CollectionInputFormat, IteratorInputFormat and ValuesInputFormat, and what they have in common is that all of them implement the NonParallelInput interface.
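To see how the three overridden methods in the listing fit together, here is a plain-Java sketch with no Flink dependency. It assumes the runtime drives each parallel instance roughly as open(split), then nextRecord() until reachedEnd() returns true. MiniSplittableSource, MiniFormat, roundRobin and runAllSplits are hypothetical names, and the round-robin source is only an illustration (NumberSequenceIterator, discussed below, cuts contiguous ranges instead).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Plain-Java sketch (no Flink dependency) of the open / reachedEnd / nextRecord
// protocol of ParallelIteratorInputFormat. All names here are hypothetical.
class InputFormatLoopSketch {

    /** Hypothetical stand-in for SplittableIterator#getSplit(num, numPartitions). */
    interface MiniSplittableSource<T> {
        Iterator<T> getSplit(int num, int numPartitions);
    }

    /** Mirrors the three overridden methods of ParallelIteratorInputFormat. */
    static final class MiniFormat<T> {
        private final MiniSplittableSource<T> source;
        private Iterator<T> splitIterator;

        MiniFormat(MiniSplittableSource<T> source) {
            this.source = source;
        }

        void open(int splitNumber, int totalSplits) {
            this.splitIterator = source.getSplit(splitNumber, totalSplits);
        }

        boolean reachedEnd() {
            return !splitIterator.hasNext();
        }

        T nextRecord() {
            return splitIterator.next();
        }
    }

    /** Hypothetical source: split i receives every element whose index % n == i. */
    static <T> MiniSplittableSource<T> roundRobin(List<T> data) {
        return (num, n) -> {
            List<T> part = new ArrayList<>();
            for (int k = num; k < data.size(); k += n) {
                part.add(data.get(k));
            }
            return part.iterator();
        };
    }

    /** Simulates one format instance per split, run one after another. */
    static <T> List<List<T>> runAllSplits(MiniSplittableSource<T> source, int numSplits) {
        List<List<T>> perSplit = new ArrayList<>();
        for (int i = 0; i < numSplits; i++) {
            MiniFormat<T> format = new MiniFormat<>(source);
            format.open(i, numSplits);
            List<T> out = new ArrayList<>();
            while (!format.reachedEnd()) {
                out.add(format.nextRecord());
            }
            perSplit.add(out);
        }
        return perSplit;
    }

    public static void main(String[] args) {
        List<Integer> data = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        System.out.println(runAllSplits(roundRobin(data), 3));
    }
}
```

Running main prints [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]: each simulated instance sees only its own slice, which is why ParallelIteratorInputFormat needs nothing more than the split number and the total number of splits to pick its part of the data.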
NonParallelInput
flink-core-1.6.2-sources.jar!/org/apache/flink/api/common/io/NonParallelInput.java
/**
* This interface acts as a marker for input formats for inputs which cannot be split.
* Data sources with a non-parallel input formats are always executed with a parallelism
* of one.
*
* @see InputFormat
*/
@Public
public interface NonParallelInput {
}
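This interface defines no methods. It is purely a marker indicating that an InputFormat cannot be split, so data sources using such a format always run with a parallelism of one.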
GenericInputFormat.createInputSplits
flink-core-1.6.2-sources.jar!/org/apache/flink/api/common/io/GenericInputFormat.java
@Override
public GenericInputSplit[] createInputSplits(int numSplits) throws IOException {
if (numSplits < 1) {
throw new IllegalArgumentException("Number of input splits has to be at least 1.");
}

numSplits = (this instanceof NonParallelInput) ? 1 : numSplits;
GenericInputSplit[] splits = new GenericInputSplit[numSplits];
for (int i = 0; i < splits.length; i++) {
splits[i] = new GenericInputSplit(i, numSplits);
}
return splits;
}
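GenericInputFormat's createInputSplits method validates the requested numSplits: if it is less than 1, an IllegalArgumentException is thrown, and if the current InputFormat implements the NonParallelInput interface, numSplits is reset to 1. It then creates one GenericInputSplit(i, numSplits) per split.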
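The split-count rule can be reproduced in a few lines of plain Java. This sketch does not use Flink: NonParallel, MiniSplit, MiniGenericFormat, ParallelFormat and SingleFormat are hypothetical stand-ins for NonParallelInput, GenericInputSplit, GenericInputFormat and its subclasses, and only the numSplits handling is modeled.

```java
import java.util.Arrays;

// Plain-Java sketch of the numSplits handling in GenericInputFormat.createInputSplits.
// All class names are hypothetical stand-ins for the Flink types.
class CreateSplitsSketch {

    /** Marker interface, like NonParallelInput: no methods, only a type tag. */
    interface NonParallel {}

    /** Stand-in for GenericInputSplit: (split number, total number of splits). */
    static final class MiniSplit {
        final int splitNumber;
        final int totalNumberOfSplits;

        MiniSplit(int splitNumber, int totalNumberOfSplits) {
            this.splitNumber = splitNumber;
            this.totalNumberOfSplits = totalNumberOfSplits;
        }

        @Override
        public String toString() {
            return splitNumber + "/" + totalNumberOfSplits;
        }
    }

    static class MiniGenericFormat {
        MiniSplit[] createInputSplits(int numSplits) {
            if (numSplits < 1) {
                throw new IllegalArgumentException("Number of input splits has to be at least 1.");
            }
            // A format tagged with the marker interface is forced down to a single split.
            numSplits = (this instanceof NonParallel) ? 1 : numSplits;
            MiniSplit[] splits = new MiniSplit[numSplits];
            for (int i = 0; i < splits.length; i++) {
                splits[i] = new MiniSplit(i, numSplits);
            }
            return splits;
        }
    }

    /** Plays the role of ParallelIteratorInputFormat: no marker, so numSplits is honoured. */
    static class ParallelFormat extends MiniGenericFormat {}

    /** Plays the role of CollectionInputFormat and friends: carries the marker. */
    static class SingleFormat extends MiniGenericFormat implements NonParallel {}

    public static void main(String[] args) {
        System.out.println(Arrays.toString(new ParallelFormat().createInputSplits(3))); // [0/3, 1/3, 2/3]
        System.out.println(Arrays.toString(new SingleFormat().createInputSplits(3)));   // [0/1]
    }
}
```

Because the check is a plain instanceof, a subclass only has to declare the marker interface to be limited to one split; no method needs to be overridden.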
ExecutionEnvironment.fromParallelCollection
flink-java-1.6.2-sources.jar!/org/apache/flink/api/java/ExecutionEnvironment.java
/**
* Creates a new data set that contains elements in the iterator. The iterator is splittable, allowing the
* framework to create a parallel data source that returns the elements in the iterator.
*
* <p>Because the iterator will remain unmodified until the actual execution happens, the type of data
* returned by the iterator must be given explicitly in the form of the type class (this is due to the
* fact that the Java compiler erases the generic type information).
*
* @param iterator The iterator that produces the elements of the data set.
* @param type The class of the data produced by the iterator. Must not be a generic class.
* @return A DataSet representing the elements in the iterator.
*
* @see #fromParallelCollection(SplittableIterator, TypeInformation)
*/
public <X> DataSource<X> fromParallelCollection(SplittableIterator<X> iterator, Class<X> type) {
return fromParallelCollection(iterator, TypeExtractor.getForClass(type));
}

/**
* Creates a new data set that contains elements in the iterator. The iterator is splittable, allowing the
* framework to create a parallel data source that returns the elements in the iterator.
*
* <p>Because the iterator will remain unmodified until the actual execution happens, the type of data
* returned by the iterator must be given explicitly in the form of the type information.
* This method is useful for cases where the type is generic. In that case, the type class
* (as given in {@link #fromParallelCollection(SplittableIterator, Class)} does not supply all type information.
*
* @param iterator The iterator that produces the elements of the data set.
* @param type The TypeInformation for the produced data set.
* @return A DataSet representing the elements in the iterator.
*
* @see #fromParallelCollection(SplittableIterator, Class)
*/
public <X> DataSource<X> fromParallelCollection(SplittableIterator<X> iterator, TypeInformation<X> type) {
return fromParallelCollection(iterator, type, Utils.getCallLocationName());
}

// private helper for passing different call location names
private <X> DataSource<X> fromParallelCollection(SplittableIterator<X> iterator, TypeInformation<X> type, String callLocationName) {
return new DataSource<>(this, new ParallelIteratorInputFormat<>(iterator), type, callLocationName);
}

/**
* Creates a new data set that contains a sequence of numbers. The data set will be created in parallel,
* so there is no guarantee about the order of the elements.
*
* @param from The number to start at (inclusive).
* @param to The number to stop at (inclusive).
* @return A DataSet, containing all number in the {@code [from, to]} interval.
*/
public DataSource<Long> generateSequence(long from, long to) {
return fromParallelCollection(new NumberSequenceIterator(from, to), BasicTypeInfo.LONG_TYPE_INFO, Utils.getCallLocationName());
}
For an iterator of type SplittableIterator, ExecutionEnvironment's fromParallelCollection methods create a ParallelIteratorInputFormat. The generateSequence method also goes through fromParallelCollection, passing it a NumberSequenceIterator (a subclass of SplittableIterator).
SplittableIterator
flink-core-1.6.2-sources.jar!/org/apache/flink/util/SplittableIterator.java
/**
* Abstract base class for iterators that can split themselves into multiple disjoint
* iterators. The union of these iterators returns the original iterator values.
*
* @param <T> The type of elements returned by the iterator.
*/
@Public
public abstract class SplittableIterator<T> implements Iterator<T>, Serializable {

private static final long serialVersionUID = 200377674313072307L;

/**
* Splits this iterator into a number disjoint iterators.
* The union of these iterators returns the original iterator values.
*
* @param numPartitions The number of iterators to split into.
* @return An array with the split iterators.
*/
public abstract Iterator<T>[] split(int numPartitions);

/**
* Splits this iterator into <i>n</i> partitions and returns the <i>i-th</i> partition
* out of those.
*
* @param num The partition to return (<i>i</i>).
* @param numPartitions The number of partitions to split into (<i>n</i>).
* @return The iterator for the partition.
*/
public Iterator<T> getSplit(int num, int numPartitions) {
if (numPartitions < 1 || num < 0 || num >= numPartitions) {
throw new IllegalArgumentException();
}

return split(numPartitions)[num];
}

/**
* The maximum number of splits into which this iterator can be split up.
*
* @return The maximum number of splits into which this iterator can be split up.
*/
public abstract int getMaximumNumberOfSplits();
}
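SplittableIterator is an abstract class that defines the abstract methods split and getMaximumNumberOfSplits, while getSplit only validates its arguments and returns split(numPartitions)[num]. It has two implementations, LongValueSequenceIterator and NumberSequenceIterator; here we look at NumberSequenceIterator.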
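The getSplit contract (validate num and numPartitions, then return split(numPartitions)[num]) can be tried out with a small list-backed iterator. This plain-Java sketch assumes nothing from Flink: MiniSplittableIterator, ListSplittableIterator and SplittableContractDemo are hypothetical names, Serializable and getMaximumNumberOfSplits are left out, and cutting the remaining list into contiguous chunks is just one possible split strategy.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/** Hypothetical, trimmed-down mirror of SplittableIterator (no Serializable, no max-splits). */
abstract class MiniSplittableIterator<T> implements Iterator<T> {

    /** Splits into numPartitions disjoint iterators whose union is the remaining sequence. */
    abstract Iterator<T>[] split(int numPartitions);

    /** Same validation and delegation as SplittableIterator#getSplit. */
    Iterator<T> getSplit(int num, int numPartitions) {
        if (numPartitions < 1 || num < 0 || num >= numPartitions) {
            throw new IllegalArgumentException();
        }
        return split(numPartitions)[num];
    }
}

/** Hypothetical list-backed implementation that cuts the unconsumed part into contiguous chunks. */
class ListSplittableIterator<T> extends MiniSplittableIterator<T> {
    private final List<T> items;
    private int pos = 0;

    ListSplittableIterator(List<T> items) {
        this.items = items;
    }

    @Override
    public boolean hasNext() {
        return pos < items.size();
    }

    @Override
    public T next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        return items.get(pos++);
    }

    @Override
    @SuppressWarnings("unchecked")
    Iterator<T>[] split(int numPartitions) {
        List<T> remaining = items.subList(pos, items.size());
        int size = remaining.size();
        Iterator<T>[] parts = new Iterator[numPartitions];
        for (int i = 0; i < numPartitions; i++) {
            // chunk i covers [i * size / n, (i + 1) * size / n) of the remaining elements
            int start = (int) ((long) i * size / numPartitions);
            int end = (int) ((long) (i + 1) * size / numPartitions);
            parts[i] = new ArrayList<>(remaining.subList(start, end)).iterator();
        }
        return parts;
    }
}

class SplittableContractDemo {

    /** Drains getSplit(i, n) for i = 0..n-1, in order, into a single list. */
    static <T> List<T> unionOfSplits(MiniSplittableIterator<T> it, int n) {
        List<T> all = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            it.getSplit(i, n).forEachRemaining(all::add);
        }
        return all;
    }

    public static void main(String[] args) {
        List<Integer> data = Arrays.asList(1, 2, 3, 4, 5, 6, 7);
        System.out.println(unionOfSplits(new ListSplittableIterator<>(data), 3)); // [1, 2, 3, 4, 5, 6, 7]
    }
}
```

Like the base-class method in the listing, getSplit recomputes the whole split array on every call and keeps only one element; that is acceptable when split only creates lightweight objects, as NumberSequenceIterator does.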
NumberSequenceIterator
flink-core-1.6.2-sources.jar!/org/apache/flink/util/NumberSequenceIterator.java
/**
* The {@code NumberSequenceIterator} is an iterator that returns a sequence of numbers (as {@code Long})s.
* The iterator is splittable (as defined by {@link SplittableIterator}, i.e., it can be divided into multiple
* iterators that each return a subsequence of the number sequence.
*/
@Public
public class NumberSequenceIterator extends SplittableIterator<Long> {

private static final long serialVersionUID = 1L;

/** The last number returned by the iterator. */
private final long to;

/** The next number to be returned. */
private long current;

/**
* Creates a new splittable iterator, returning the range [from, to].
* Both boundaries of the interval are inclusive.
*
* @param from The first number returned by the iterator.
* @param to The last number returned by the iterator.
*/
public NumberSequenceIterator(long from, long to) {
if (from > to) {
throw new IllegalArgumentException("The 'to' value must not be smaller than the 'from' value.");
}

this.current = from;
this.to = to;
}

@Override
public boolean hasNext() {
return current <= to;
}

@Override
public Long next() {
if (current <= to) {
return current++;
} else {
throw new NoSuchElementException();
}
}

@Override
public NumberSequenceIterator[] split(int numPartitions) {
if (numPartitions < 1) {
throw new IllegalArgumentException("The number of partitions must be at least 1.");
}

if (numPartitions == 1) {
return new NumberSequenceIterator[] { new NumberSequenceIterator(current, to) };
}

// here, numPartitions >= 2 !!!

long elementsPerSplit;

if (to - current + 1 >= 0) {
elementsPerSplit = (to - current + 1) / numPartitions;
}
else {
// long overflow of the range.
// we compute based on half the distance, to prevent the overflow.
// in most cases it holds that: current < 0 and to > 0, except for: to == 0 and current == Long.MIN_VALUE
// the later needs a special case
final long halfDiff; // must be positive

if (current == Long.MIN_VALUE) {
// this means to >= 0
halfDiff = (Long.MAX_VALUE / 2 + 1) + to / 2;
} else {
long posFrom = -current;
if (posFrom > to) {
halfDiff = to + ((posFrom - to) / 2);
} else {
halfDiff = posFrom + ((to - posFrom) / 2);
}
}
elementsPerSplit = halfDiff / numPartitions * 2;
}

if (elementsPerSplit < Long.MAX_VALUE) {
// figure out how many get one in addition
long numWithExtra = -(elementsPerSplit * numPartitions) + to - current + 1;

// based on rounding errors, we may have lost one)
if (numWithExtra > numPartitions) {
elementsPerSplit++;
numWithExtra -= numPartitions;

if (numWithExtra > numPartitions) {
throw new RuntimeException("Bug in splitting logic. To much rounding loss.");
}
}

NumberSequenceIterator[] iters = new NumberSequenceIterator[numPartitions];
long curr = current;
int i = 0;
for (; i < numWithExtra; i++) {
long next = curr + elementsPerSplit + 1;
iters[i] = new NumberSequenceIterator(curr, next - 1);
curr = next;
}
for (; i < numPartitions; i++) {
long next = curr + elementsPerSplit;
iters[i] = new NumberSequenceIterator(curr, next - 1, true);
curr = next;
}

return iters;
}
else {
// this can only be the case when there are two partitions
if (numPartitions != 2) {
throw new RuntimeException("Bug in splitting logic.");
}

return new NumberSequenceIterator[] {
new NumberSequenceIterator(current, current + elementsPerSplit),
new NumberSequenceIterator(current + elementsPerSplit, to)
};
}
}

@Override
public int getMaximumNumberOfSplits() {
if (to >= Integer.MAX_VALUE || current <= Integer.MIN_VALUE || to - current + 1 >= Integer.MAX_VALUE) {
return Integer.MAX_VALUE;
}
else {
return (int) (to - current + 1);
}
}

//……
}

NumberSequenceIterator's constructor takes two parameters, from and to (both inclusive). Internally it keeps a current value, which initially equals from.
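The split method first computes elementsPerSplit from numPartitions. When to - current + 1 >= 0 (that is, the element count does not overflow a long), the formula is (to - current + 1) / numPartitions; otherwise it falls back to a computation based on half the distance (halfDiff) to avoid the overflow.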
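The >= 0 test is needed because Java long arithmetic wraps around silently. The plain-Java sketch below (the class name is hypothetical, no Flink dependency) evaluates the same count expression for a small range and for a range that wraps, and copies the halfDiff computation from the overflow branch of the listing so its result can be checked.

```java
// Plain-Java sketch: the count expression used by split() and the halfDiff
// fallback from its overflow branch. The class name is hypothetical.
class RangeOverflowSketch {

    /** Element count of [current, to] using the same expression as split(); may wrap. */
    static long naiveCount(long current, long to) {
        return to - current + 1;
    }

    /** Copy of the overflow-branch halfDiff computation from the split() listing. */
    static long halfDiff(long current, long to) {
        if (current == Long.MIN_VALUE) {
            // this means to >= 0
            return (Long.MAX_VALUE / 2 + 1) + to / 2;
        }
        long posFrom = -current;
        return posFrom > to ? to + ((posFrom - to) / 2) : posFrom + ((to - posFrom) / 2);
    }

    public static void main(String[] args) {
        System.out.println(naiveCount(15, 106));            // 92
        System.out.println(naiveCount(-1, Long.MAX_VALUE)); // -9223372036854775807 (wrapped)
        System.out.println(halfDiff(-1, Long.MAX_VALUE));   // 4611686018427387904 (2^62)
    }
}
```

For [-1, Long.MAX_VALUE] the true count is 2^63 + 1, which does not fit in a long, so the naive expression turns negative and split() takes the halfDiff branch, which yields 2^62, roughly half of the range.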
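It then uses elementsPerSplit to compute numWithExtra. Because elementsPerSplit comes from integer division, giving every split exactly elementsPerSplit elements can leave some elements over, and numWithExtra is that remainder. If numWithExtra is greater than numPartitions, elementsPerSplit is incremented by 1 and numPartitions is subtracted from numWithExtra. In the non-overflow branch numWithExtra equals (to - current + 1) % numPartitions and is therefore always smaller than numPartitions, so this correction only matters for the rounding loss of the overflow branch (hence the "we may have lost one" comment).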
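Finally, the first loop builds the first numWithExtra splits, each of which receives elementsPerSplit + 1 elements, so the leftover elements are spread one per split over those splits. The splits from index numWithExtra up to numPartitions receive exactly elementsPerSplit elements, i.e. each split's to is its start + elementsPerSplit - 1.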
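The non-overflow path can be condensed into a function that returns the bounds of each split instead of iterators. This is a plain-Java sketch, not Flink code: SequenceSplitSketch and splitRanges are hypothetical names, the overflow branch is deliberately left out, and a single loop gives one extra element to each of the first numWithExtra splits instead of the two loops in the listing.

```java
import java.util.ArrayList;
import java.util.List;

// Plain-Java sketch of the non-overflow path of NumberSequenceIterator#split.
// Returns inclusive {from, to} bounds per split; names are hypothetical.
class SequenceSplitSketch {

    static List<long[]> splitRanges(long current, long to, int numPartitions) {
        if (numPartitions < 1) {
            throw new IllegalArgumentException("The number of partitions must be at least 1.");
        }
        long total = to - current + 1;
        if (total <= 0) {
            // empty range, or a range whose size overflows a long (not handled here)
            throw new IllegalArgumentException("Range is empty or too large for this sketch.");
        }
        long elementsPerSplit = total / numPartitions;                 // integer division
        long numWithExtra = total - elementsPerSplit * numPartitions;  // remainder, < numPartitions

        List<long[]> ranges = new ArrayList<>();
        long curr = current;
        for (int i = 0; i < numPartitions; i++) {
            // the first numWithExtra splits take one extra element each
            long size = elementsPerSplit + (i < numWithExtra ? 1 : 0);
            ranges.add(new long[] {curr, curr + size - 1}); // size 0 gives an empty range
            curr += size;
        }
        return ranges;
    }

    public static void main(String[] args) {
        for (long[] r : splitRanges(15, 106, 3)) {
            System.out.println("[" + r[0] + ", " + r[1] + "]");
        }
    }
}
```

Assuming three splits are requested for the generateSequence(15, 106) example, there are 92 numbers, so elementsPerSplit = 30 and numWithExtra = 2, and the splits are [15, 45] (31 numbers), [46, 76] (31 numbers) and [77, 106] (30 numbers).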
getMaximumNumberOfSplits returns the maximum number of splits the iterator can be divided into: it returns Integer.MAX_VALUE when (to >= Integer.MAX_VALUE || current <= Integer.MIN_VALUE || to - current + 1 >= Integer.MAX_VALUE) holds, and (int) (to - current + 1) otherwise.
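The clamping rule is easy to check in isolation. Below, the method body from the listing is copied into a static method of a hypothetical class so it runs without Flink. The two guards on to and current also ensure that to - current + 1 cannot overflow a long before the final cast to int.

```java
// Plain-Java copy of NumberSequenceIterator#getMaximumNumberOfSplits as a static
// method; the class name is hypothetical.
class MaxSplitsSketch {

    static int maxSplits(long current, long to) {
        // Clamp to Integer.MAX_VALUE whenever the element count might not fit in an int
        if (to >= Integer.MAX_VALUE || current <= Integer.MIN_VALUE || to - current + 1 >= Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        return (int) (to - current + 1);
    }

    public static void main(String[] args) {
        System.out.println(maxSplits(15, 106));            // 92: one split per element
        System.out.println(maxSplits(0, 10_000_000_000L)); // 2147483647 (clamped)
    }
}
```

For the 15..106 example it returns 92, one split per element; any range reaching past the int limits is simply reported as Integer.MAX_VALUE.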

Summary

GenericInputFormat has five subclasses: besides ParallelIteratorInputFormat, they are CRowValuesInputFormat, CollectionInputFormat, IteratorInputFormat and ValuesInputFormat, and the latter four all implement the NonParallelInput interface.
GenericInputFormat's createInputSplits validates the requested numSplits; for a NonParallelInput format it forces the value to 1.
NumberSequenceIterator is an implementation of SplittableIterator. For a SplittableIterator, ExecutionEnvironment's fromParallelCollection method creates a ParallelIteratorInputFormat, and generateSequence goes through fromParallelCollection with a NumberSequenceIterator. NumberSequenceIterator's split method first computes elementsPerSplit and then numWithExtra; the first numWithExtra splits each receive one extra element, and the remaining splits receive exactly elementsPerSplit elements.

doc

ParallelIteratorInputFormat
SplittableIterator
NumberSequenceIterator
