RPC客户端工厂TransportClientFactory
TransportClientFactory是创建TransportClient的工厂类。TransportContext的createClientFactory方法可以创建TransportClientFactory的实例
/** * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning * a new Client. Bootstraps will be executed synchronously, and must run successfully in order * to create a Client. */ public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) { return new TransportClientFactory(this, bootstraps); } public TransportClientFactory createClientFactory() { return createClientFactory(Lists.<TransportClientBootstrap>newArrayList()); }
可以看到,TransportContext中有两个重载的createClientFactory方法,它们最终在构造TransportClientFactory时都会传递两个参数:TransportContext和TransportClientBootstrap列表。TransportClientFactory构造器的实现如代码所示。
public TransportClientFactory( TransportContext context, List<TransportClientBootstrap> clientBootstraps) { this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap<>(); this.numConnectionsPerPeer = conf.numConnectionsPerPeer(); this.rand = new Random(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); // TODO: Make thread pool name configurable. this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); this.pooledAllocator = NettyUtils.createPooledByteBufAllocator( conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); }
TransportClientFactory构造器中的各个变量如下:
context:参数传递的TransportContext的引用。
conf:指TransportConf,这里通过调用TransportContext的getConf获取。
clientBootstraps:参数传递的TransportClientBootstrap列表。
connectionPool:针对每个Socket地址的连接池ClientPool的缓存。
connectionPool的数据结构较为复杂,为便于读者理解,这里以图来表示connectionPool的数据结构。
numConnectionsPerPeer:从TransportConf获 取 的key为“spark.+模 块 名+.io.num-ConnectionsPerPeer”的属性值。此属性值用于指定对等节点间的连接数。这里的模块名实际为TransportConf的module字段。Spark的很多组件都利用RPC框架构建,它们之间按照模块名区分,例如,RPC模块的key为“spark.rpc.io.numConnectionsPerPeer”。
#TransportConf中的 getConfKey 方法获取参数private String getConfKey(String suffix) { return "spark." + module + "." + suffix; }
rand:对Socket地址对应的连接池ClientPool中缓存的
TransportClient进行随机选择,对每个连接做负载均衡。
ioMode:IO模式,即从TransportConf获取key为“spark.+模块名+.io.mode”的属性值。默认值为NIO, Spark还支持EPOLL。
socketChannelClass:客户端Channel被创建时使用的类,通过ioMode来匹配,默认为NioSocketChannel, Spark还支持EpollEventLoopGroup。
workerGroup:根据Netty的规范,客户端只有worker组,所以此处创建worker-Group。workerGroup的实际类型是NioEventLoopGroup。
pooledAllocator :汇集ByteBuf但对本地线程缓存禁用的分配器。
客户端引导程序TransportClientBootstrap
TransportClientFactory的clientBootstraps属性是TransportClientBootstrap的列表。Transport ClientBootstrap是在TransportClient上执行的客户端引导程序,主要对连接建立时进行一些初始化的准备(例如验证、加密)。TransportClientBootstrap所做的操作往往是昂贵的,好在建立的连接可以重用。TransportClientBootstrap的接口定义如代码清单3-10所示:
import io.netty.channel.Channel;/** * A bootstrap which is executed on a TransportClient before it is returned to the user. * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- * connection basis. * * Since connections (and TransportClients) are reused as much as possible, it is generally * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with * the JVM itself. */public interface TransportClientBootstrap { /** Performs the bootstrapping operation, throwing an exception on failure. */ void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;}
TransportClientBootstrap有两个实现类:EncryptionDisablerBootstrap和SaslClientBootstrap。
创建RPC客户端TransportClient
有了TransportClientFactory, Spark的各个模块就可以使用它创建RPC客户端TransportClient了。每个TransportClient实例只能和一个远端的RPC服务通信,所以Spark中的组件如果想要和多个RPC服务通信,就需要持有多个TransportClient实例。创建TransportClient的方法如代码所示(实际为从缓存中获取TransportClient)。
/** * Create a {@link TransportClient} connecting to the given remote host / port. * * We maintains an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer) * and randomly picks one to use. If no client was previously created in the randomly selected * spot, this function creates a new client and places it there. * * Prior to the creation of a new TransportClient, we will execute all * {@link TransportClientBootstrap}s that are registered with this factory. * * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. */ public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. // Use unresolved address here to avoid DNS resolution each time we creates a client. final InetSocketAddress unresolvedAddress = InetSocketAddress.createUnresolved(remoteHost, remotePort); // Create the ClientPool if we don't have it yet. ClientPool clientPool = connectionPool.get(unresolvedAddress); if (clientPool == null) { connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer)); clientPool = connectionPool.get(unresolvedAddress); } int clientIndex = rand.nextInt(numConnectionsPerPeer); TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { // Make sure that the channel will not timeout by updating the last use time of the // handler. Then check that the client is still alive, in case it timed out before // this code was able to update things. TransportChannelHandler handler = cachedClient.getChannel().pipeline() .get(TransportChannelHandler.class); synchronized (handler) { handler.getResponseHandler().updateTimeOfLastRequest(); } if (cachedClient.isActive()) { logger.trace("Returning cached connection to {}: {}", cachedClient.getSocketAddress(), cachedClient); return cachedClient; } } // If we reach here, we don't have an existing connection open. Let's create a new one. // Multiple threads might race here to create new connections. Keep only one of them active. final long preResolveHost = System.nanoTime(); final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort); final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; if (hostResolveTimeMs > 2000) { logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); } else { logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); } synchronized (clientPool.locks[clientIndex]) { cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null) { if (cachedClient.isActive()) { logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient); return cachedClient; } else { logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress); } } clientPool.clients[clientIndex] = createClient(resolvedAddress); return clientPool.clients[clientIndex]; } }
从代码得知,创建TransportClient的步骤如下。
1)调用InetSocketAddress的静态方法createUnresolved构建InetSocketAddress(这种方式创建InetSocketAddress,可以在缓存中已经有TransportClient时避免不必要的域名解析),然后从connectionPool中获取与此地址对应的ClientPool,如果没有,则需要新建ClientPool,并放入缓存connectionPool中。
2)根据numConnectionsPerPeer的大小(使用“spark.+模块名+.io.numConnections-PerPeer”属性配置),从ClientPool中随机选择一个TransportClient。
3)如果ClientPool的clients数组中在随机产生的索引位置不存在TransportClient或者TransportClient没有激活,则进入第5步,否则对此TransportClient进行第4步的检查。
4)更新TransportClient的channel中配置的TransportChannelHandler的最后一次使用时间,确保channel没有超时,然后检查TransportClient是否是激活状态,最后返回此TransportClient给调用方。
5)由于缓存中没有TransportClient可用,于是调用InetSocketAddress的构造器创建InetSocketAddress对象(直接使用InetSocketAddress的构造器创建InetSocketAddress会进行域名解析),在这一步骤多个线程可能会产生竞态条件(由于没有同步处理,所以多个线程极有可能同时执行到此处,都发现缓存中没有TransportClient可用,于是都使用InetSocketAddress的构造器创建InetSocketAddress)。
6)第5步创建InetSocketAddress的过程中产生的竞态条件如果不妥善处理,会产生线程安全问题,所以到了ClientPool的locks数组发挥作用的时候了。按照随机产生的数组索引,locks数组中的锁对象可以对clients数组中的TransportClient一对一进行同步。即便之前产生了竞态条件,但是在这一步只能有一个线程进入临界区。在临界区内,先进入的线程调用重载的createClient方法创建TransportClient对象并放入ClientPool的clients数组中。当率先进入临界区的线程退出临界区后,其他线程才能进入,此时发现ClientPool的clients数组中已经存在了TransportClient对象,那么将不再创建TransportClient,而是直接使用它。
下面代码的整个执行过程实际解决了TransportClient缓存的使用及createClient方法的线程安全问题,并没有涉及创建TransportClient的实现。TransportClient的创建过程在重载的createClient方法中实现。
/** Create a completely new {@link TransportClient} to the remote address. */ private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to {}", address); Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup) .channel(socketChannelClass) // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, true) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) .option(ChannelOption.ALLOCATOR, pooledAllocator); final AtomicReference<TransportClient> clientRef = new AtomicReference<>(); final AtomicReference<Channel> channelRef = new AtomicReference<>(); bootstrap.handler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); clientRef.set(clientHandler.getClient()); channelRef.set(ch); } }); // Connect to the remote server long preConnect = System.nanoTime(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); } TransportClient client = clientRef.get(); Channel channel = channelRef.get(); assert client != null : "Channel future completed successfully with null client"; // Execute any client bootstraps synchronously before marking the Client as successful. long preBootstrap = System.nanoTime(); logger.debug("Connection to {} successful, running bootstraps...", address); try { for (TransportClientBootstrap clientBootstrap : clientBootstraps) { clientBootstrap.doBootstrap(client, channel); } } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000; logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e); client.close(); throw Throwables.propagate(e); } long postBootstrap = System.nanoTime(); logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000); return client; }
从代码得知,真正创建TransportClient的步骤如下。
1)构建根引导程序Bootstrap并对其进行配置。
2)为根引导程序设置管道初始化回调函数,此回调函数将调用TransportContext的initializePipeline方法初始化Channel的pipeline。
3)使用根引导程序连接远程服务器,当连接成功对管道初始化时会回调初始化回调函数,将TransportClient和Channel对象分别设置到原子引用clientRef与channelRef中。
4)给TransportClient设置客户端引导程序,即设置TransportClientFactory中的Transport-ClientBootstrap列表。
5)返回此TransportClient对象。
博客基于《Spark内核设计的艺术:架构设计与实现》一书