RPC client factory: TransportClientFactory
TransportClientFactory is the factory class that creates TransportClient instances. An instance of TransportClientFactory is obtained by calling the createClientFactory method of TransportContext:
/**
 * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
 * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
 * to create a Client.
 */
public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
  return new TransportClientFactory(this, bootstraps);
}

public TransportClientFactory createClientFactory() {
  return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
}
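As shown, TransportContext has two overloaded createClientFactory methods. Both end up passing two arguments to the TransportClientFactory constructor: the TransportContext itself and a list of TransportClientBootstrap (empty for the no-argument overload). The TransportClientFactory constructor is implemented as follows.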
public TransportClientFactory(
    TransportContext context,
    List<TransportClientBootstrap> clientBootstraps) {
  this.context = Preconditions.checkNotNull(context);
  this.conf = context.getConf();
  this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
  this.connectionPool = new ConcurrentHashMap<>();
  this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
  this.rand = new Random();

  IOMode ioMode = IOMode.valueOf(conf.ioMode());
  this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
  // TODO: Make thread pool name configurable.
  this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
  this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
    conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads());
}
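The variables set in the TransportClientFactory constructor are as follows:
context: the reference to the TransportContext passed in as an argument.
conf: the TransportConf, obtained here by calling getConf on the TransportContext.
clientBootstraps: the list of TransportClientBootstrap passed in as an argument.
connectionPool: a cache that holds one connection pool (ClientPool) per socket address.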
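The data structure of connectionPool is fairly involved. In short, it is a map from a socket address to a ClientPool, and each ClientPool holds a clients array of TransportClient together with a locks array of the same length, one lock object per client slot.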
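numConnectionsPerPeer: the value of the property whose key is "spark.<module>.io.numConnectionsPerPeer", read from TransportConf. It specifies the number of connections between two peer nodes. The module name here is the module field of TransportConf. Many Spark components are built on the RPC framework and are distinguished by module name; for the RPC module, for example, the key is "spark.rpc.io.numConnectionsPerPeer".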
// The getConfKey method of TransportConf builds the full property key
private String getConfKey(String suffix) {
  return "spark." + module + "." + suffix;
}
rand: used to pick one of the TransportClients cached in the ClientPool of a socket address at random, which spreads the load across the connections.
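ioMode: the IO mode, that is, the value of the property whose key is "spark.<module>.io.mode", read from TransportConf. The default is NIO; Spark also supports EPOLL.
socketChannelClass: the class used when the client Channel is created, chosen according to ioMode. The default is NioSocketChannel; with EPOLL it is EpollSocketChannel.
workerGroup: by Netty convention a client has only a worker group, so only workerGroup is created here. With the default NIO mode its actual type is NioEventLoopGroup (EpollEventLoopGroup in EPOLL mode).
pooledAllocator: an allocator that pools ByteBufs but has thread-local caching disabled.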
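To make the shape of connectionPool concrete, here is a minimal sketch using only the JDK. It is not Spark's code: FakeClient is a hypothetical stand-in for TransportClient, and ConnectionPoolSketch and poolFor are names invented for illustration. The ClientPool layout (a clients array plus a locks array of the same size) is inferred from how createClient uses clientPool.clients and clientPool.locks.

```java
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.ConcurrentHashMap;

final class ConnectionPoolSketch {
    /** Hypothetical stand-in for TransportClient; it only carries a label. */
    static final class FakeClient {
        final String label;
        FakeClient(String label) { this.label = label; }
    }

    /** One pool per remote address: a fixed array of client slots plus one lock per slot. */
    static final class ClientPool {
        final FakeClient[] clients;
        final Object[] locks;

        ClientPool(int size) {
            clients = new FakeClient[size];
            locks = new Object[size];
            for (int i = 0; i < size; i++) {
                locks[i] = new Object();
            }
        }
    }

    final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool = new ConcurrentHashMap<>();
    final int numConnectionsPerPeer;

    ConnectionPoolSketch(int numConnectionsPerPeer) {
        this.numConnectionsPerPeer = numConnectionsPerPeer;
    }

    /** Returns the pool for host:port, creating it on first use. */
    ClientPool poolFor(String host, int port) {
        // Unresolved addresses compare by host name and port, so no DNS lookup is needed.
        SocketAddress key = InetSocketAddress.createUnresolved(host, port);
        ClientPool pool = connectionPool.get(key);
        if (pool == null) {
            // putIfAbsent keeps the first pool if two threads race to create one.
            connectionPool.putIfAbsent(key, new ClientPool(numConnectionsPerPeer));
            pool = connectionPool.get(key);
        }
        return pool;
    }

    public static void main(String[] args) {
        ConnectionPoolSketch factory = new ConnectionPoolSketch(2);
        ClientPool a = factory.poolFor("host-a", 7077);
        System.out.println(a == factory.poolFor("host-a", 7077)); // same address, same pool
        System.out.println(a == factory.poolFor("host-b", 7077)); // different address, different pool
        System.out.println(a.clients.length);                      // one slot per allowed connection
    }
}
```

The point of the two-level layout is that a lookup by address is lock-free, while creating a connection only has to lock a single slot of a single pool.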
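Client bootstrap: TransportClientBootstrap
The clientBootstraps field of TransportClientFactory is a list of TransportClientBootstrap. A TransportClientBootstrap is a client bootstrap that runs on a TransportClient, mainly to perform one-time initialization when a connection is established (for example authentication and encryption). The work done by a TransportClientBootstrap is often expensive, but fortunately an established connection can be reused. The TransportClientBootstrap interface is defined as shown in Listing 3-10: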
import io.netty.channel.Channel;

/**
 * A bootstrap which is executed on a TransportClient before it is returned to the user.
 * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
 * connection basis.
 *
 * Since connections (and TransportClients) are reused as much as possible, it is generally
 * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
 * the JVM itself.
 */
public interface TransportClientBootstrap {
  /** Performs the bootstrapping operation, throwing an exception on failure. */
  void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;
}
TransportClientBootstrap has two implementations: EncryptionDisablerBootstrap and SaslClientBootstrap.
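The contract of a client bootstrap (run once per connection, fail by throwing) can be illustrated with a small JDK-only sketch. Everything here is hypothetical: FakeClient stands in for TransportClient, ClientBootstrap is a simplified analogue of TransportClientBootstrap without the Netty Channel argument, and bootstrapAll is an invented helper that mirrors how the factory is described as running the bootstraps.

```java
import java.util.ArrayList;
import java.util.List;

final class BootstrapSketch {
    /** Hypothetical stand-in for TransportClient: it records what was done to it. */
    static final class FakeClient {
        final List<String> log = new ArrayList<>();
        boolean closed = false;
        void close() { closed = true; }
    }

    /** Simplified analogue of TransportClientBootstrap (no Channel argument). */
    interface ClientBootstrap {
        void doBootstrap(FakeClient client);
    }

    /** Runs every bootstrap in order; closes the client and rethrows if one of them fails. */
    static FakeClient bootstrapAll(FakeClient client, List<ClientBootstrap> bootstraps) {
        try {
            for (ClientBootstrap bootstrap : bootstraps) {
                bootstrap.doBootstrap(client);
            }
        } catch (RuntimeException e) {
            client.close(); // a half-initialized client must not be handed out
            throw e;
        }
        return client;
    }

    public static void main(String[] args) {
        List<ClientBootstrap> steps = new ArrayList<>();
        steps.add(c -> c.log.add("authenticate"));
        steps.add(c -> c.log.add("negotiate-encryption"));
        System.out.println(bootstrapAll(new FakeClient(), steps).log);

        steps.add(c -> { throw new IllegalStateException("bad token"); });
        FakeClient failing = new FakeClient();
        try {
            bootstrapAll(failing, steps);
        } catch (IllegalStateException e) {
            System.out.println("closed after failure: " + failing.closed);
        }
    }
}
```

Because the bootstraps run synchronously and in list order, a caller never sees a client on which authentication has only partly completed.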
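Creating the RPC client: TransportClient
With a TransportClientFactory in hand, each Spark module can use it to create the RPC client, TransportClient. Each TransportClient instance can communicate with only one remote RPC service, so a Spark component that needs to talk to several RPC services must hold several TransportClient instances. The method that creates a TransportClient is shown below (in practice it first tries to fetch a TransportClient from the cache).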
/**
 * Create a {@link TransportClient} connecting to the given remote host / port.
 *
 * We maintains an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
 * and randomly picks one to use. If no client was previously created in the randomly selected
 * spot, this function creates a new client and places it there.
 *
 * Prior to the creation of a new TransportClient, we will execute all
 * {@link TransportClientBootstrap}s that are registered with this factory.
 *
 * This blocks until a connection is successfully established and fully bootstrapped.
 *
 * Concurrency: This method is safe to call from multiple threads.
 */
public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
  // Get connection from the connection pool first.
  // If it is not found or not active, create a new one.
  // Use unresolved address here to avoid DNS resolution each time we creates a client.
  final InetSocketAddress unresolvedAddress =
    InetSocketAddress.createUnresolved(remoteHost, remotePort);

  // Create the ClientPool if we don't have it yet.
  ClientPool clientPool = connectionPool.get(unresolvedAddress);
  if (clientPool == null) {
    connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
    clientPool = connectionPool.get(unresolvedAddress);
  }

  int clientIndex = rand.nextInt(numConnectionsPerPeer);
  TransportClient cachedClient = clientPool.clients[clientIndex];

  if (cachedClient != null && cachedClient.isActive()) {
    // Make sure that the channel will not timeout by updating the last use time of the
    // handler. Then check that the client is still alive, in case it timed out before
    // this code was able to update things.
    TransportChannelHandler handler = cachedClient.getChannel().pipeline()
      .get(TransportChannelHandler.class);
    synchronized (handler) {
      handler.getResponseHandler().updateTimeOfLastRequest();
    }

    if (cachedClient.isActive()) {
      logger.trace("Returning cached connection to {}: {}",
        cachedClient.getSocketAddress(), cachedClient);
      return cachedClient;
    }
  }

  // If we reach here, we don't have an existing connection open. Let's create a new one.
  // Multiple threads might race here to create new connections. Keep only one of them active.
  final long preResolveHost = System.nanoTime();
  final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
  final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
  if (hostResolveTimeMs > 2000) {
    logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
  } else {
    logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
  }

  synchronized (clientPool.locks[clientIndex]) {
    cachedClient = clientPool.clients[clientIndex];

    if (cachedClient != null) {
      if (cachedClient.isActive()) {
        logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
        return cachedClient;
      } else {
        logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
      }
    }
    clientPool.clients[clientIndex] = createClient(resolvedAddress);
    return clientPool.clients[clientIndex];
  }
}
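From the code, the steps for obtaining a TransportClient are as follows.
1) Call the static method createUnresolved of InetSocketAddress to build an InetSocketAddress (built this way, no unnecessary DNS resolution happens when a TransportClient is already cached), then look up the ClientPool for this address in connectionPool. If there is none, create a new ClientPool and put it into the connectionPool cache.
2) Generate a random index in the range [0, numConnectionsPerPeer), where numConnectionsPerPeer is configured by the property "spark.<module>.io.numConnectionsPerPeer", and use it to pick a TransportClient from the ClientPool.
3) If the clients array of the ClientPool holds no TransportClient at the randomly generated index, or the TransportClient there is not active, go to step 5. Otherwise run the check in step 4 on this TransportClient.
4) Update the last-use time of the TransportChannelHandler configured in the TransportClient's channel so that the channel does not time out, then check once more that the TransportClient is still active. If it is, return it to the caller; if it timed out in the meantime, continue with step 5.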
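5) Since no usable TransportClient is cached, create an InetSocketAddress with the InetSocketAddress constructor (constructing it directly performs DNS resolution). Several threads can race at this step: there is no synchronization yet, so multiple threads may well arrive here at the same time, each find that no TransportClient is available in the cache, and each construct an InetSocketAddress.
6) If the race in step 5 were not handled properly it would cause thread-safety problems, and this is where the locks array of ClientPool comes in. Indexed by the same random index, each lock object in the locks array guards exactly one TransportClient slot in the clients array. Even though several threads raced earlier, only one of them can enter the critical section at this step. Inside it, the first thread calls the overloaded createClient method to create a TransportClient and stores it in the clients array of the ClientPool. Other threads can enter only after the first one has left; they then find that the clients array already holds an active TransportClient, so they do not create another one and simply use it.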
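The check, lock, re-check pattern from steps 3 to 6 can be sketched with the JDK alone. This is an illustration rather than Spark's implementation: FakeClient is a hypothetical stand-in for TransportClient, the created counter exists only to make the effect observable, and raceOnSlot is an invented helper that lets several threads compete for the same slot.

```java
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

final class SlotLockingSketch {
    /** Hypothetical stand-in for TransportClient. */
    static final class FakeClient {
        volatile boolean active = true;
        boolean isActive() { return active; }
    }

    final FakeClient[] clients;
    final Object[] locks;
    final Random rand = new Random();
    final AtomicInteger created = new AtomicInteger(); // counts real creations, for illustration

    SlotLockingSketch(int size) {
        clients = new FakeClient[size];
        locks = new Object[size];
        for (int i = 0; i < size; i++) {
            locks[i] = new Object();
        }
    }

    /** Picks a random slot, as the factory does with rand.nextInt(numConnectionsPerPeer). */
    FakeClient getOrCreate() {
        return getOrCreate(rand.nextInt(clients.length));
    }

    FakeClient getOrCreate(int index) {
        // Fast path without a lock: reuse the cached client if it is still active.
        FakeClient cached = clients[index];
        if (cached != null && cached.isActive()) {
            return cached;
        }
        // Slow path: only one thread per slot may create; the others re-check after it.
        synchronized (locks[index]) {
            cached = clients[index];
            if (cached != null && cached.isActive()) {
                return cached;
            }
            created.incrementAndGet();
            clients[index] = new FakeClient();
            return clients[index];
        }
    }

    /** Lets several threads race for slot 0 and reports how many clients were created. */
    static int raceOnSlot(int threads) {
        SlotLockingSketch pool = new SlotLockingSketch(1);
        ExecutorService exec = Executors.newFixedThreadPool(threads);
        for (int i = 0; i < threads; i++) {
            exec.execute(() -> pool.getOrCreate(0));
        }
        exec.shutdown();
        try {
            if (!exec.awaitTermination(10, TimeUnit.SECONDS)) {
                throw new IllegalStateException("workers did not finish in time");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return pool.created.get();
    }

    public static void main(String[] args) {
        System.out.println("clients created by 8 racing threads: " + raceOnSlot(8));
    }
}
```

Locking per slot rather than per pool or per factory means that threads connecting to different peers, or using different slots of the same peer, never wait on each other.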
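The code above as a whole only deals with using the TransportClient cache and making createClient thread-safe; it does not contain the actual creation of a TransportClient. That is implemented in the overloaded createClient method shown below.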
/** Create a completely new {@link TransportClient} to the remote address. */
private TransportClient createClient(InetSocketAddress address) throws IOException {
  logger.debug("Creating new connection to {}", address);

  Bootstrap bootstrap = new Bootstrap();
  bootstrap.group(workerGroup)
    .channel(socketChannelClass)
    // Disable Nagle's Algorithm since we don't want packets to wait
    .option(ChannelOption.TCP_NODELAY, true)
    .option(ChannelOption.SO_KEEPALIVE, true)
    .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
    .option(ChannelOption.ALLOCATOR, pooledAllocator);

  final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
  final AtomicReference<Channel> channelRef = new AtomicReference<>();

  bootstrap.handler(new ChannelInitializer<SocketChannel>() {
    @Override
    public void initChannel(SocketChannel ch) {
      TransportChannelHandler clientHandler = context.initializePipeline(ch);
      clientRef.set(clientHandler.getClient());
      channelRef.set(ch);
    }
  });

  // Connect to the remote server
  long preConnect = System.nanoTime();
  ChannelFuture cf = bootstrap.connect(address);
  if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
    throw new IOException(
      String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
  } else if (cf.cause() != null) {
    throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
  }

  TransportClient client = clientRef.get();
  Channel channel = channelRef.get();
  assert client != null : "Channel future completed successfully with null client";

  // Execute any client bootstraps synchronously before marking the Client as successful.
  long preBootstrap = System.nanoTime();
  logger.debug("Connection to {} successful, running bootstraps...", address);
  try {
    for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
      clientBootstrap.doBootstrap(client, channel);
    }
  } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
    long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
    logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
    client.close();
    throw Throwables.propagate(e);
  }
  long postBootstrap = System.nanoTime();

  logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
    address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);

  return client;
}
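From the code, the steps that actually create a TransportClient are as follows.
1) Build the root bootstrap (Netty's Bootstrap) and configure it.
2) Set a channel-initialization callback on the root bootstrap. This callback calls the initializePipeline method of TransportContext to initialize the pipeline of the Channel.
3) Use the root bootstrap to connect to the remote server. When the connection succeeds and the pipeline is initialized, the callback is invoked and stores the TransportClient and the Channel in the atomic references clientRef and channelRef respectively. If the connection times out or fails, an IOException is thrown.
4) Run the client bootstraps on the TransportClient, that is, call doBootstrap for each TransportClientBootstrap in the TransportClientFactory's list. If any of them throws, the client is closed and the exception is propagated.
5) Return the TransportClient.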
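Step 3 relies on a small idiom: the object is created inside a callback, so the caller hands the callback an AtomicReference and reads it once the connect future has completed. The sketch below shows that idiom with the JDK only. It is an assumption-laden illustration: CompletableFuture is swapped in for Netty's ChannelFuture, plain strings stand in for Channel and TransportClient, and connect and createClient are hypothetical helpers, not the Netty or Spark API.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

final class CallbackCaptureSketch {
    /** Hypothetical framework call: runs the initializer on another thread, then completes. */
    static CompletableFuture<Void> connect(String address, Consumer<String> initializer) {
        return CompletableFuture.runAsync(() -> initializer.accept("channel-to-" + address));
    }

    /** Connects, waits up to timeoutMs, and returns what the initializer captured. */
    static String createClient(String address, long timeoutMs) {
        // A lambda can only capture effectively final locals, so a holder object is needed.
        final AtomicReference<String> clientRef = new AtomicReference<>();
        CompletableFuture<Void> cf = connect(address, ch -> clientRef.set("client-on-" + ch));
        try {
            cf.get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            // Mirrors the "timed out" branch: the wait ended before the connect finished.
            throw new IllegalStateException("Connecting to " + address + " timed out", e);
        } catch (ExecutionException e) {
            // Mirrors the "cause != null" branch: the connect finished, but with an error.
            throw new IllegalStateException("Failed to connect to " + address, e.getCause());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        String client = clientRef.get();
        if (client == null) {
            throw new IllegalStateException("future completed without running the initializer");
        }
        return client;
    }

    public static void main(String[] args) {
        System.out.println(createClient("example-host:7077", 5000));
    }
}
```

Separating the timeout case from the failure case, as the original method does, lets the caller report whether the remote side was unreachable within the configured time or actively refused the connection.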
This post is based on the book 《Spark 内核设计的艺术:架构设计与实现》.