About ShardingSphere: How ShardingSphere Rewrites SQL

This article looks at how ShardingSphere rewrites SQL. The code excerpts are taken from the ShardingSphere 5.4.0 sources.

prepareStatement

org/apache/shardingsphere/driver/jdbc/core/connection/ShardingSphereConnection.java

public final class ShardingSphereConnection extends AbstractConnectionAdapter {

    @Override
    public PreparedStatement prepareStatement(final String sql) throws SQLException {
        return new ShardingSpherePreparedStatement(this, sql);
    }

    //......
}    

ShardingSphereConnection.prepareStatement creates a ShardingSpherePreparedStatement.

ShardingSpherePreparedStatement

org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java

public final class ShardingSpherePreparedStatement extends AbstractPreparedStatementAdapter {
    
    @Getter
    private final ShardingSphereConnection connection;

    public ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql) throws SQLException {
        this(connection, sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, ResultSet.HOLD_CURSORS_OVER_COMMIT, false, null);
    }

    private ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql,
                                            final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys,
                                            final String[] columns) throws SQLException {
        if (Strings.isNullOrEmpty(sql)) {
            throw new EmptySQLException().toSQLException();
        }
        this.connection = connection;
        metaDataContexts = connection.getContextManager().getMetaDataContexts();
        SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
        hintValueContext = sqlParserRule.isSqlCommentParseEnabled() ? new HintValueContext() : SQLHintUtils.extractHint(sql).orElseGet(HintValueContext::new);
        this.sql = sqlParserRule.isSqlCommentParseEnabled() ? sql : SQLHintUtils.removeHint(sql);
        statements = new ArrayList<>();
        parameterSets = new ArrayList<>();
        SQLParserEngine sqlParserEngine = sqlParserRule.getSQLParserEngine(DatabaseTypeEngine.getTrunkDatabaseTypeName(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType()));
        sqlStatement = sqlParserEngine.parse(this.sql, true);
        sqlStatementContext = SQLStatementContextFactory.newInstance(metaDataContexts.getMetaData(), sqlStatement, connection.getDatabaseName());
        parameterMetaData = new ShardingSphereParameterMetaData(sqlStatement);
        statementOption = returnGeneratedKeys ? new StatementOption(true, columns) : new StatementOption(resultSetType, resultSetConcurrency, resultSetHoldability);
        executor = new DriverExecutor(connection);
        JDBCExecutor jdbcExecutor = new JDBCExecutor(connection.getContextManager().getExecutorEngine(), connection.getDatabaseConnectionManager().getConnectionContext());
        batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(metaDataContexts, jdbcExecutor, connection.getDatabaseName());
        kernelProcessor = new KernelProcessor();
        statementsCacheable = isStatementsCacheable(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getRuleMetaData());
        trafficRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(TrafficRule.class);
        selectContainsEnhancedTable = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsEnhancedTable();
        statementManager = new StatementManager();
    }

    //......
}    

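ShardingSpherePreparedStatement extends AbstractPreparedStatementAdapter. Its constructor mainly uses SQLParserEngine to parse the SQL into an SQLStatement, and it creates the DriverExecutor, BatchPreparedStatementExecutor, KernelProcessor and StatementManager. The constructor does not prepare any statement on a real driver connection. So even with useServerPrepStmts=true, no MySQL server-side prepare is triggered at this point.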

executeUpdate

    public int executeUpdate() throws SQLException {
        try {
            if (statementsCacheable && !statements.isEmpty()) {
                resetParameters();
                return statements.iterator().next().executeUpdate();
            }
            clearPrevious();
            QueryContext queryContext = createQueryContext();
            trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
            if (null != trafficInstanceId) {
                JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
                return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate());
            }
            executionContext = createExecutionContext(queryContext);
            if (hasRawExecutionRule()) {
                Collection<ExecuteResult> executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback());
                return accumulate(executeResults);
            }
            return isNeedImplicitCommitTransaction(connection, executionContext) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate();
            // CHECKSTYLE:OFF
        } catch (final RuntimeException ex) {
            // CHECKSTYLE:ON
            handleExceptionInTransaction(connection, metaDataContexts);
            throw SQLExceptionTransformEngine.toSQLException(ex, metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType().getType());
        } finally {
            clearBatch();
        }
    }

    private void clearPrevious() {
        statements.clear();
        parameterSets.clear();
        generatedValues.clear();
    }

    private ExecutionContext createExecutionContext(final QueryContext queryContext) {
        ShardingSphereRuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
        ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName());
        SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
        ExecutionContext result = kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
        findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
        return result;
    }

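executeUpdate first checks whether cached statements can be reused, which requires statementsCacheable and a non-empty statements list. If they cannot, it calls clearPrevious to clear statements, parameterSets and generatedValues. It then calls createExecutionContext, which runs the SQL audit and then calls kernelProcessor.generateExecutionContext.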

KernelProcessor

generateExecutionContext

shardingsphere-infra-context-5.4.0-sources.jar!/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java

    public ExecutionContext generateExecutionContext(final QueryContext queryContext, final ShardingSphereDatabase database, final ShardingSphereRuleMetaData globalRuleMetaData,
                                                     final ConfigurationProperties props, final ConnectionContext connectionContext) {
        RouteContext routeContext = route(queryContext, database, globalRuleMetaData, props, connectionContext);
        SQLRewriteResult rewriteResult = rewrite(queryContext, database, globalRuleMetaData, props, routeContext, connectionContext);
        ExecutionContext result = createExecutionContext(queryContext, database, routeContext, rewriteResult);
        logSQL(queryContext, props, result);
        return result;
    }

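KernelProcessor.generateExecutionContext first routes the query to obtain a RouteContext, then rewrites the SQL based on that route result, and finally calls createExecutionContext to build the ExecutionContext. After that, logSQL logs the resulting SQL.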
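The toy model below makes the ordering concrete. It assumes a single logic table t_order split across two data sources. MiniKernel, Target and the names ds_0/ds_1 and t_order_0/t_order_1 are hypothetical. The plain string replacement stands in for ShardingSphere's token-based rewriting. The point is only that rewriting consumes the route result, and that each execution unit pairs a rewritten SQL with its data source.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified model of generateExecutionContext:
// route first, then rewrite per route target, then build execution units.
final class MiniKernel {

    // A route target: which data source and which actual table (hypothetical type).
    static final class Target {
        final String dataSource;
        final String actualTable;

        Target(String dataSource, String actualTable) {
            this.dataSource = dataSource;
            this.actualTable = actualTable;
        }
    }

    // Stage 1: routing. Toy rule: the logic table t_order lives in two shards.
    static List<Target> route(String sql) {
        List<Target> result = new ArrayList<>();
        if (sql.contains("t_order")) {
            result.add(new Target("ds_0", "t_order_0"));
            result.add(new Target("ds_1", "t_order_1"));
        }
        return result;
    }

    // Stage 2: rewriting needs the route result to know the actual table names.
    // Naive string replacement here; a real engine works on token positions.
    static Map<Target, String> rewrite(String sql, List<Target> targets) {
        Map<Target, String> result = new LinkedHashMap<>();
        for (Target each : targets) {
            result.put(each, sql.replace("t_order", each.actualTable));
        }
        return result;
    }

    // Stage 3: execution units are (data source, rewritten SQL) pairs, in route order.
    static List<String> generateExecutionUnits(String sql) {
        List<Target> targets = route(sql);
        Map<Target, String> rewritten = rewrite(sql, targets);
        List<String> result = new ArrayList<>();
        rewritten.forEach((target, actualSql) -> result.add(target.dataSource + ": " + actualSql));
        return result;
    }
}
```

For SELECT * FROM t_order WHERE id = ?, this yields one unit for ds_0 against t_order_0 and one for ds_1 against t_order_1.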

rewrite

    private SQLRewriteResult rewrite(final QueryContext queryContext, final ShardingSphereDatabase database, final ShardingSphereRuleMetaData globalRuleMetaData,
                                     final ConfigurationProperties props, final RouteContext routeContext, final ConnectionContext connectionContext) {
        SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(database, globalRuleMetaData, props);
        return sqlRewriteEntry.rewrite(queryContext.getSql(), queryContext.getParameters(), queryContext.getSqlStatementContext(), routeContext, connectionContext, queryContext.getHintValueContext());
    }

rewrite simply creates an SQLRewriteEntry and delegates to its rewrite method.

SQLRewriteEntry

shardingsphere-infra-rewrite-5.4.0-sources.jar!/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java

    /**
     * Rewrite.
     * 
     * @param sql SQL
     * @param params SQL parameters
     * @param sqlStatementContext SQL statement context
     * @param routeContext route context
     * @param connectionContext connection context
     * @param hintValueContext hint value context
     * 
     * @return route unit and SQL rewrite result map
     */
    public SQLRewriteResult rewrite(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
                                    final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
        SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext);
        SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
        DatabaseType protocolType = database.getProtocolType();
        Map<String, DatabaseType> storageTypes = database.getResourceMetaData().getStorageTypes();
        return routeContext.getRouteUnits().isEmpty()
                ? new GenericSQLRewriteEngine(rule, protocolType, storageTypes).rewrite(sqlRewriteContext)
                : new RouteSQLRewriteEngine(rule, protocolType, storageTypes).rewrite(sqlRewriteContext, routeContext);
    }

    private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
                                                      final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
        SQLRewriteContext result = new SQLRewriteContext(database.getName(), database.getSchemas(), sqlStatementContext, sql, params, connectionContext, hintValueContext);
        decorate(decorators, result, routeContext, hintValueContext);
        result.generateSQLTokens();
        return result;
    }

    private void decorate(final Map<ShardingSphereRule, SQLRewriteContextDecorator> decorators, final SQLRewriteContext sqlRewriteContext,
                          final RouteContext routeContext, final HintValueContext hintValueContext) {
        if (hintValueContext.isSkipSQLRewrite()) {
            return;
        }
        for (Entry<ShardingSphereRule, SQLRewriteContextDecorator> entry : decorators.entrySet()) {
            entry.getValue().decorate(entry.getKey(), props, sqlRewriteContext, routeContext);
        }
    }

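SQLRewriteEntry.rewrite first builds an SQLRewriteContext in createSQLRewriteContext. There, decorate iterates over decorators and calls each SQLRewriteContextDecorator's decorate method, unless the hint asks to skip SQL rewriting. After that, generateSQLTokens is called on the context. The actual rewrite is then done by GenericSQLRewriteEngine when there are no route units, and by RouteSQLRewriteEngine otherwise.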
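Here is a simplified sketch of the decorate-then-choose-engine step. DecoratorChain, RewriteContext, Decorator and named are hypothetical stand-ins. Sorting by getOrder() is an assumption based on SQLRewriteContextDecorator being an OrderedSPI, and the order values used in the example are made up.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical model of SQLRewriteEntry.decorate: ordered decorators each
// contribute to a shared rewrite context, unless a hint disables rewriting.
final class DecoratorChain {

    // Stand-in for SQLRewriteContext: just records which decorators ran.
    static final class RewriteContext {
        final List<String> contributions = new ArrayList<>();
    }

    // Stand-in for SQLRewriteContextDecorator (an ordered SPI).
    interface Decorator {
        int getOrder();

        void decorate(RewriteContext context);
    }

    static RewriteContext createRewriteContext(List<Decorator> decorators, boolean skipSQLRewrite) {
        RewriteContext result = new RewriteContext();
        if (skipSQLRewrite) {
            return result; // mirrors the early return on hintValueContext.isSkipSQLRewrite()
        }
        List<Decorator> sorted = new ArrayList<>(decorators);
        sorted.sort(Comparator.comparingInt(Decorator::getOrder)); // assumed ascending order
        for (Decorator each : sorted) {
            each.decorate(result);
        }
        return result;
    }

    // Mirrors the engine choice: no route units -> generic engine, otherwise route engine.
    static String chooseEngine(int routeUnitCount) {
        return routeUnitCount == 0 ? "GenericSQLRewriteEngine" : "RouteSQLRewriteEngine";
    }

    // Hypothetical convenience factory: a decorator that only records its name.
    static Decorator named(String name, int order) {
        return new Decorator() {
            @Override
            public int getOrder() {
                return order;
            }

            @Override
            public void decorate(RewriteContext context) {
                context.contributions.add(name);
            }
        };
    }
}
```

Take named("encrypt", 20) and named("sharding", 0) as an example. createRewriteContext(..., false) records sharding before encrypt. Passing true, which plays the skip-rewrite hint, leaves the context untouched.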

SQLRewriteContextDecorator

org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextDecorator.java

@SingletonSPI
public interface SQLRewriteContextDecorator<T extends ShardingSphereRule> extends OrderedSPI<T> {
    
    /**
     * Decorate SQL rewrite context.
     *
     * @param rule rule
     * @param props ShardingSphere properties
     * @param sqlRewriteContext SQL rewrite context to be decorated
     * @param routeContext route context
     */
    void decorate(T rule, ConfigurationProperties props, SQLRewriteContext sqlRewriteContext, RouteContext routeContext);
}

SQLRewriteContextDecorator defines the decorate method. Its implementations include ShardingSQLRewriteContextDecorator and EncryptSQLRewriteContextDecorator.

EncryptSQLRewriteContextDecorator

org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java

/**
 * SQL rewrite context decorator for encrypt.
 */
public final class EncryptSQLRewriteContextDecorator implements SQLRewriteContextDecorator<EncryptRule> {
    
    @Override
    public void decorate(final EncryptRule encryptRule, final ConfigurationProperties props, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
        SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
        if (!containsEncryptTable(encryptRule, sqlStatementContext)) {
            return;
        }
        Collection<EncryptCondition> encryptConditions = createEncryptConditions(encryptRule, sqlRewriteContext);
        if (!sqlRewriteContext.getParameters().isEmpty()) {
            Collection<ParameterRewriter> parameterRewriters = new EncryptParameterRewriterBuilder(encryptRule,
                    sqlRewriteContext.getDatabaseName(), sqlRewriteContext.getSchemas(), sqlStatementContext, encryptConditions).getParameterRewriters();
            rewriteParameters(sqlRewriteContext, parameterRewriters);
        }
        Collection<SQLTokenGenerator> sqlTokenGenerators = new EncryptTokenGenerateBuilder(encryptRule,
                sqlStatementContext, encryptConditions, sqlRewriteContext.getDatabaseName()).getSQLTokenGenerators();
        sqlRewriteContext.addSQLTokenGenerators(sqlTokenGenerators);
    }
    
    private Collection<EncryptCondition> createEncryptConditions(final EncryptRule encryptRule, final SQLRewriteContext sqlRewriteContext) {
        SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
        if (!(sqlStatementContext instanceof WhereAvailable)) {
            return Collections.emptyList();
        }
        Collection<WhereSegment> whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments();
        Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
        return new EncryptConditionEngine(encryptRule, sqlRewriteContext.getSchemas())
                .createEncryptConditions(whereSegments, columnSegments, sqlStatementContext, sqlRewriteContext.getDatabaseName());
    }
    
    private boolean containsEncryptTable(final EncryptRule encryptRule, final SQLStatementContext sqlStatementContext) {
        for (String each : sqlStatementContext.getTablesContext().getTableNames()) {
            if (encryptRule.findEncryptTable(each).isPresent()) {
                return true;
            }
        }
        return false;
    }
    
    private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
        for (ParameterRewriter each : parameterRewriters) {
            each.rewrite(sqlRewriteContext.getParameterBuilder(), sqlRewriteContext.getSqlStatementContext(), sqlRewriteContext.getParameters());
        }
    }
    
    @Override
    public int getOrder() {
        return EncryptOrder.ORDER;
    }
    
    @Override
    public Class<EncryptRule> getTypeClass() {
        return EncryptRule.class;
    }
}

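In EncryptSQLRewriteContextDecorator, rewriteParameters runs only when the statement has parameters. It rewrites them through ParameterRewriter instances, which mainly modify the ParameterBuilder. Changes to the SQL text itself are made by the SQLTokenGenerators registered through addSQLTokenGenerators.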
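On the parameter side, the encrypt rewrite can be pictured as replacing the values bound to encrypted columns while the SQL text stays untouched. EncryptParamRewriter is hypothetical. The caller supplies the parameter positions here; in ShardingSphere they would be derived from the EncryptCondition objects. Base64 is only a stand-in for a real encryption algorithm.

```java
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Set;

// Hypothetical parameter rewriter in the spirit of the encrypt ParameterRewriters:
// only the parameter values change; the SQL text is left to token generators.
final class EncryptParamRewriter {

    // Toy "encryption" (Base64), standing in for a real encrypt algorithm.
    static String encrypt(Object plain) {
        return Base64.getEncoder().encodeToString(String.valueOf(plain).getBytes(StandardCharsets.UTF_8));
    }

    // Returns a new parameter list; positions bound to encrypted columns are replaced.
    // encryptedParamIndexes are 0-based positions in params (an assumption of this sketch).
    static List<Object> rewrite(List<Object> params, Set<Integer> encryptedParamIndexes) {
        List<Object> result = new ArrayList<>(params);
        for (int index : encryptedParamIndexes) {
            result.set(index, encrypt(params.get(index)));
        }
        return result;
    }
}
```

With parameters [1, "secret"] and index 1 marked as encrypted, the rewritten list is [1, "c2VjcmV0"], and the original list is left unchanged.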

SQLToken

@RequiredArgsConstructor
@Getter
public abstract class SQLToken implements Comparable<SQLToken> {
    
    private final int startIndex;
    
    @Override
    public final int compareTo(final SQLToken sqlToken) {
        return startIndex - sqlToken.startIndex;
    }
}

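SQLToken has implementations such as InsertValuesToken, SubstitutableColumnNameToken and InsertColumnsToken. Every token records its startIndex in the original SQL, and compareTo orders tokens by that position.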
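Because every token carries a start position, the rewritten SQL can be assembled by walking the tokens in startIndex order and copying the untouched text between them. The sketch below is a simplified, hypothetical version of that idea. SubstitutionToken and its inclusive stopIndex are assumptions made for the example, and the tokens are assumed not to overlap.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical token-based rewriter: each token replaces the original text in
// [startIndex, stopIndex] (inclusive); tokens are ordered by startIndex, as SQLToken.compareTo does.
final class TokenRewriter {

    static final class SubstitutionToken implements Comparable<SubstitutionToken> {
        final int startIndex;
        final int stopIndex;
        final String text;

        SubstitutionToken(int startIndex, int stopIndex, String text) {
            this.startIndex = startIndex;
            this.stopIndex = stopIndex;
            this.text = text;
        }

        @Override
        public int compareTo(SubstitutionToken other) {
            return Integer.compare(startIndex, other.startIndex);
        }
    }

    // Assumes tokens do not overlap.
    static String rewrite(String sql, List<SubstitutionToken> tokens) {
        List<SubstitutionToken> sorted = new ArrayList<>(tokens);
        Collections.sort(sorted);
        StringBuilder result = new StringBuilder();
        int cursor = 0;
        for (SubstitutionToken each : sorted) {
            result.append(sql, cursor, each.startIndex); // untouched text before the token
            result.append(each.text);                     // replacement for the token
            cursor = each.stopIndex + 1;
        }
        result.append(sql.substring(cursor));             // remaining tail
        return result.toString();
    }
}
```

Take SELECT pwd FROM t_user WHERE pwd = ? as input. Substituting both pwd occurrences with pwd_cipher and t_user with t_user_0 yields SELECT pwd_cipher FROM t_user_0 WHERE pwd_cipher = ?, whatever order the tokens are supplied in.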

RouteSQLRewriteEngine

    /**
     * Rewrite SQL and parameters.
     *
     * @param sqlRewriteContext SQL rewrite context
     * @param routeContext route context
     * @return SQL rewrite result
     */
    public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
        Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1F);
        for (Entry<String, Collection<RouteUnit>> entry : aggregateRouteUnitGroups(routeContext.getRouteUnits()).entrySet()) {
            Collection<RouteUnit> routeUnits = entry.getValue();
            if (isNeedAggregateRewrite(sqlRewriteContext.getSqlStatementContext(), routeUnits)) {
                sqlRewriteUnits.put(routeUnits.iterator().next(), createSQLRewriteUnit(sqlRewriteContext, routeContext, routeUnits));
            } else {
                addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
            }
        }
        return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext().getSqlStatement(), sqlRewriteUnits));
    }

    private void addSQLRewriteUnits(final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits, final SQLRewriteContext sqlRewriteContext,
                                    final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
        for (RouteUnit each : routeUnits) {
            sqlRewriteUnits.put(each, new SQLRewriteUnit(new RouteSQLBuilder(sqlRewriteContext, each).toSQL(), getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each)));
        }
    }

    private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatement sqlStatement, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
        Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
        for (Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
            DatabaseType storageType = storageTypes.get(entry.getKey().getDataSourceMapper().getActualName());
            String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatement, protocolType, storageType);
            SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters());
            result.put(entry.getKey(), sqlRewriteUnit);
        }
        return result;
    }

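addSQLRewriteUnits adds one SQLRewriteUnit per route unit to sqlRewriteUnits, using RouteSQLBuilder to produce that unit's SQL. translate then builds a new SQLRewriteUnit for each entry, passing its SQL through translatorRule.translate with the storage type of the target data source. An SQLRewriteUnit holds the rewritten SQL together with the correspondingly rewritten parameters.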
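The translate pass is a second loop over the already rewritten units. It keeps their order and may adapt each SQL to the storage type of its data source. The sketch keys units by data source name instead of RouteUnit. It also drops the parameters, which translate carries over unchanged. TranslatePass and the backtick-to-double-quote translator in the example are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.BiFunction;

// Hypothetical model of the translate step: a second pass over the per-unit SQL
// that may adapt each SQL to the storage type of its data source.
final class TranslatePass {

    // translator: (sql, storageType) -> translated sql (hypothetical signature).
    static Map<String, String> translate(Map<String, String> sqlByDataSource,
                                         Map<String, String> storageTypes,
                                         BiFunction<String, String, String> translator) {
        // LinkedHashMap keeps the unit order, like the original code.
        Map<String, String> result = new LinkedHashMap<>(sqlByDataSource.size(), 1F);
        for (Map.Entry<String, String> entry : sqlByDataSource.entrySet()) {
            String storageType = storageTypes.get(entry.getKey());
            result.put(entry.getKey(), translator.apply(entry.getValue(), storageType));
        }
        return result;
    }
}
```

For instance, the translator (sql, type) -> "PostgreSQL".equals(type) ? sql.replace('`', '"') : sql leaves MySQL-bound SQL as is. It converts identifier quotes only for the PostgreSQL data source.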

useDriverToExecuteUpdate

org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java

    private int useDriverToExecuteUpdate() throws SQLException {
        ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
        cacheStatements(executionGroupContext.getInputGroups());
        return executor.getRegularExecutor().executeUpdate(executionGroupContext,
                executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
    }

    private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext() throws SQLException {
        DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine();
        return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(connection.getDatabaseName()));
    } 

    private void cacheStatements(final Collection<ExecutionGroup<JDBCExecutionUnit>> executionGroups) throws SQLException {
        for (ExecutionGroup<JDBCExecutionUnit> each : executionGroups) {
            each.getInputs().forEach(eachInput -> {
                statements.add((PreparedStatement) eachInput.getStorageResource());
                parameterSets.add(eachInput.getExecutionUnit().getSqlUnit().getParameters());
            });
        }
        replay();
    }

    private void replay() throws SQLException {
        replaySetParameter();
        for (Statement each : statements) {
            getMethodInvocationRecorder().replay(each);
        }
    }

    private void replaySetParameter() throws SQLException {
        for (int i = 0; i < statements.size(); i++) {
            replaySetParameter(statements.get(i), parameterSets.get(i));
        }
    }

    protected final void replaySetParameter(final PreparedStatement preparedStatement, final List<Object> params) throws SQLException {
        setParameterMethodInvocations.clear();
        addParameters(params);
        for (PreparedStatementInvocationReplayer each : setParameterMethodInvocations) {
            each.replayOn(preparedStatement);
        }
    }

    private void addParameters(final List<Object> params) {
        int i = 0;
        for (Object each : params) {
            int index = ++i;
            setParameterMethodInvocations.add(preparedStatement -> preparedStatement.setObject(index, each));
        }
    }

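useDriverToExecuteUpdate first calls createExecutionGroupContext, which runs the prepare engine. cacheStatements then does three things:
- It adds each input's storage resource (eachInput.getStorageResource(), the real PreparedStatement) to the statements field of ShardingSpherePreparedStatement.
- It adds the matching eachInput.getExecutionUnit().getSqlUnit().getParameters() to parameterSets.
- It calls replay, which uses PreparedStatementInvocationReplayer to apply the rewritten parameters to the real PreparedStatement as 1-based setObject calls, and also replays any other recorded method invocations on it.
Finally, the method delegates to executor.getRegularExecutor().executeUpdate, passing createExecuteUpdateCallback() as the last argument.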
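The record-then-replay step can be reproduced with plain JDBC types. In this sketch, ParameterReplay and Replayer are hypothetical, with Replayer playing the role of PreparedStatementInvocationReplayer. record mirrors addParameters by capturing 1-based setObject calls. loggingStatement is a java.lang.reflect.Proxy test double, so no database is needed.

```java
import java.lang.reflect.Proxy;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the replay step: parameters are first recorded as
// invocations, then replayed onto the real (driver-created) PreparedStatement.
final class ParameterReplay {

    // Plays the role of PreparedStatementInvocationReplayer.
    interface Replayer {
        void replayOn(PreparedStatement statement) throws SQLException;
    }

    // Like addParameters: JDBC parameter indexes are 1-based.
    static List<Replayer> record(List<Object> params) {
        List<Replayer> result = new ArrayList<>();
        int i = 0;
        for (Object each : params) {
            int index = ++i;
            result.add(statement -> statement.setObject(index, each));
        }
        return result;
    }

    static void replay(List<Replayer> invocations, PreparedStatement target) throws SQLException {
        for (Replayer each : invocations) {
            each.replayOn(target);
        }
    }

    // Test double: a PreparedStatement proxy that only supports setObject(int, Object) and logs it.
    static PreparedStatement loggingStatement(List<String> log) {
        return (PreparedStatement) Proxy.newProxyInstance(
                ParameterReplay.class.getClassLoader(),
                new Class<?>[] {PreparedStatement.class},
                (proxy, method, args) -> {
                    if ("setObject".equals(method.getName()) && args != null && args.length == 2) {
                        log.add(args[0] + "=" + args[1]);
                        return null;
                    }
                    throw new UnsupportedOperationException(method.getName());
                });
    }
}
```

Recording [42, "alice"] and replaying it onto the logging proxy produces the calls setObject(1, 42) and setObject(2, "alice"), in that order.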

DriverExecutionPrepareEngine.prepare

org/apache/shardingsphere/infra/executor/sql/prepare/AbstractExecutionPrepareEngine.java

    public final ExecutionGroupContext<T> prepare(final RouteContext routeContext, final Collection<ExecutionUnit> executionUnits,
                                                  final ExecutionGroupReportContext reportContext) throws SQLException {
        return prepare(routeContext, Collections.emptyMap(), executionUnits, reportContext);
    }

    public final ExecutionGroupContext<T> prepare(final RouteContext routeContext, final Map<String, Integer> connectionOffsets, final Collection<ExecutionUnit> executionUnits,
                                                  final ExecutionGroupReportContext reportContext) throws SQLException {
        Collection<ExecutionGroup<T>> result = new LinkedList<>();
        for (Entry<String, List<SQLUnit>> entry : aggregateSQLUnitGroups(executionUnits).entrySet()) {
            String dataSourceName = entry.getKey();
            List<SQLUnit> sqlUnits = entry.getValue();
            List<List<SQLUnit>> sqlUnitGroups = group(sqlUnits);
            ConnectionMode connectionMode = maxConnectionsSizePerQuery < sqlUnits.size() ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
            result.addAll(group(dataSourceName, connectionOffsets.getOrDefault(dataSourceName, 0), sqlUnitGroups, connectionMode));
        }
        return decorate(routeContext, result, reportContext);
    }

    protected List<ExecutionGroup<T>> group(final String dataSourceName, final int connectionOffset, final List<List<SQLUnit>> sqlUnitGroups, final ConnectionMode connectionMode) throws SQLException {
        List<ExecutionGroup<T>> result = new LinkedList<>();
        List<C> connections = databaseConnectionManager.getConnections(dataSourceName, connectionOffset, sqlUnitGroups.size(), connectionMode);
        int count = 0;
        for (List<SQLUnit> each : sqlUnitGroups) {
            result.add(createExecutionGroup(dataSourceName, each, connections.get(count++), connectionMode));
        }
        return result;
    }

    private ExecutionGroup<T> createExecutionGroup(final String dataSourceName, final List<SQLUnit> sqlUnits, final C connection, final ConnectionMode connectionMode) throws SQLException {
        List<T> result = new LinkedList<>();
        for (SQLUnit each : sqlUnits) {
            result.add((T) sqlExecutionUnitBuilder.build(new ExecutionUnit(dataSourceName, each), statementManager, connection, connectionMode, option, databaseTypes.get(dataSourceName)));
        }
        return new ExecutionGroup<>(result);
    }

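prepare aggregates the SQL units by data source and chooses a connection mode for each data source. group then calls createExecutionGroup for every SQL unit group, and createExecutionGroup calls sqlExecutionUnitBuilder.build for each SQLUnit. The connections returned by databaseConnectionManager.getConnections are real connections obtained through the actual driver, com.mysql.jdbc.Driver in this case.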
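The first part of prepare can be sketched as follows. SQL units are grouped by data source, and the connection-mode comparison is copied from the excerpt above. PrepareSketch and SqlUnit are hypothetical. The excerpt does not show how group(sqlUnits) splits one data source's units into sqlUnitGroups, so that step is left out.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the first part of prepare: SQL units are aggregated
// per data source, and the connection mode depends on how many units a data source got.
final class PrepareSketch {

    enum ConnectionMode { MEMORY_STRICTLY, CONNECTION_STRICTLY }

    static final class SqlUnit {
        final String dataSourceName;
        final String sql;

        SqlUnit(String dataSourceName, String sql) {
            this.dataSourceName = dataSourceName;
            this.sql = sql;
        }
    }

    // Groups units by data source name, keeping first-seen order.
    static Map<String, List<SqlUnit>> aggregateByDataSource(List<SqlUnit> units) {
        Map<String, List<SqlUnit>> result = new LinkedHashMap<>();
        for (SqlUnit each : units) {
            result.computeIfAbsent(each.dataSourceName, key -> new ArrayList<>()).add(each);
        }
        return result;
    }

    // Same comparison as in AbstractExecutionPrepareEngine.prepare.
    static ConnectionMode connectionMode(int maxConnectionsSizePerQuery, int sqlUnitCount) {
        return maxConnectionsSizePerQuery < sqlUnitCount ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
    }
}
```

With maxConnectionsSizePerQuery = 1, a data source that received two SQL units gets CONNECTION_STRICTLY, and one that received a single unit gets MEMORY_STRICTLY.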

PreparedStatementExecutionUnitBuilder

org/apache/shardingsphere/infra/executor/sql/prepare/driver/jdbc/builder/PreparedStatementExecutionUnitBuilder.java

    public JDBCExecutionUnit build(final ExecutionUnit executionUnit, final ExecutorJDBCStatementManager statementManager,
                                   final Connection connection, final ConnectionMode connectionMode, final StatementOption option, final DatabaseType databaseType) throws SQLException {
        PreparedStatement preparedStatement = createPreparedStatement(executionUnit, statementManager, connection, connectionMode, option, databaseType);
        return new JDBCExecutionUnit(executionUnit, connectionMode, preparedStatement);
    }

    private PreparedStatement createPreparedStatement(final ExecutionUnit executionUnit, final ExecutorJDBCStatementManager statementManager, final Connection connection,
                                                      final ConnectionMode connectionMode, final StatementOption option, final DatabaseType databaseType) throws SQLException {
        return (PreparedStatement) statementManager.createStorageResource(executionUnit, connection, connectionMode, option, databaseType);
    }

The real PreparedStatement is only created here, in PreparedStatementExecutionUnitBuilder.build.

StatementManager

org/apache/shardingsphere/driver/jdbc/core/statement/StatementManager.java

    public Statement createStorageResource(final ExecutionUnit executionUnit, final Connection connection, final ConnectionMode connectionMode, final StatementOption option,
                                           final DatabaseType databaseType) throws SQLException {
        Statement result = cachedStatements.get(new CacheKey(executionUnit, connectionMode));
        if (null == result || result.isClosed() || result.getConnection().isClosed()) {
            String sql = executionUnit.getSqlUnit().getSql();
            if (option.isReturnGeneratedKeys()) {
                result = null == option.getColumns() || 0 == option.getColumns().length
                        ? connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)
                        : connection.prepareStatement(sql, option.getColumns());
            } else {
                result = connection.prepareStatement(sql, option.getResultSetType(), option.getResultSetConcurrency(), option.getResultSetHoldability());
            }
            cachedStatements.put(new CacheKey(executionUnit, connectionMode), result);
        }
        return result;
    }

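createStorageResource creates the real PreparedStatement via connection.prepareStatement, and the SQL passed in at this point is already the rewritten SQL. Statements are cached per CacheKey(executionUnit, connectionMode). A new one is created only when nothing is cached, or when the cached statement or its connection has been closed.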
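The caching rule in createStorageResource can be expressed as a small generic cache. ResourceCache and Key are hypothetical. The key uses the SQL string plus the connection mode, a simplification of CacheKey(executionUnit, connectionMode). The isClosed predicate stands for the result.isClosed() || result.getConnection().isClosed() check.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Predicate;

// Hypothetical generic version of the StatementManager cache: reuse the cached
// resource for the same key unless it is no longer usable, otherwise create a new one.
final class ResourceCache<R> {

    static final class Key {
        final String sql;
        final String connectionMode;

        Key(String sql, String connectionMode) {
            this.sql = sql;
            this.connectionMode = connectionMode;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Key)) {
                return false;
            }
            Key other = (Key) o;
            return sql.equals(other.sql) && connectionMode.equals(other.connectionMode);
        }

        @Override
        public int hashCode() {
            return Objects.hash(sql, connectionMode);
        }
    }

    private final Map<Key, R> cached = new HashMap<>();
    private final Predicate<R> isClosed;       // e.g. statement or its connection closed
    private final Function<String, R> factory; // stands in for connection.prepareStatement(sql, ...)

    ResourceCache(Predicate<R> isClosed, Function<String, R> factory) {
        this.isClosed = isClosed;
        this.factory = factory;
    }

    R get(String sql, String connectionMode) {
        Key key = new Key(sql, connectionMode);
        R result = cached.get(key);
        if (null == result || isClosed.test(result)) {
            result = factory.apply(sql);
            cached.put(key, result);
        }
        return result;
    }
}
```

Repeated calls with the same SQL and mode return the same object until it reports itself closed. At that point a fresh one is created and cached in its place.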

createExecuteUpdateCallback

org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java

    private JDBCExecutorCallback<Integer> createExecuteUpdateCallback() {
        boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
        return new JDBCExecutorCallback<Integer>(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType(),
                metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getResourceMetaData().getStorageTypes(), sqlStatement, isExceptionThrown) {
            
            @Override
            protected Integer executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
                return ((PreparedStatement) statement).executeUpdate();
            }
            
            @Override
            protected Optional<Integer> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
                return Optional.empty();
            }
        };
    }

The JDBCExecutorCallback created by createExecuteUpdateCallback implements executeSQL as ((PreparedStatement) statement).executeUpdate(). In other words, execution is delegated to the real PreparedStatement.
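The callback follows a template-method split: the executor owns the iteration over execution units, and the callback only decides how one statement is executed. CallbackSketch, ExecutorCallback and Unit are hypothetical simplifications of JDBCExecutorCallback. The "statement" is a generic S so the sketch runs without a database, and the executor simply collects one result per unit.

```java
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the callback pattern: the executor iterates over
// execution units, and the callback decides how a single statement is executed.
final class CallbackSketch {

    // Stand-in for JDBCExecutorCallback<T>; there, S would be a real JDBC Statement.
    abstract static class ExecutorCallback<S, T> {
        protected abstract T executeSQL(String sql, S statement) throws SQLException;
    }

    // One execution unit: rewritten SQL plus the statement prepared for it.
    static final class Unit<S> {
        final String sql;
        final S statement;

        Unit(String sql, S statement) {
            this.sql = sql;
            this.statement = statement;
        }
    }

    // Runs the callback once per unit and collects the results in unit order.
    static <S, T> List<T> execute(List<Unit<S>> units, ExecutorCallback<S, T> callback) throws SQLException {
        List<T> result = new ArrayList<>();
        for (Unit<S> each : units) {
            result.add(callback.executeSQL(each.sql, each.statement));
        }
        return result;
    }
}
```

In ShardingSpherePreparedStatement the equivalent of executeSQL casts the statement to PreparedStatement and calls executeUpdate(). In the example, each fake statement is just the update count it should return.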

Summary

  • ShardingSphereConnection.prepareStatement creates a ShardingSpherePreparedStatement. The SQL is rewritten when ShardingSpherePreparedStatement.executeUpdate runs, and only then are the real statements prepared. Execution goes through a JDBCExecutorCallback whose executeSQL calls ((PreparedStatement) statement).executeUpdate(), i.e. it delegates to the real PreparedStatement.
  • rewriteParameters rewrites the parameters through ParameterRewriter instances, mainly by modifying the ParameterBuilder. The SQL text itself is changed through the SQLTokenGenerators.
  • The real PreparedStatement is only created in PreparedStatementExecutionUnitBuilder.build. It calls StatementManager.createStorageResource, which calls connection.prepareStatement with the already rewritten SQL.
  • useDriverToExecuteUpdate calls createExecutionGroupContext, which runs prepare. cacheStatements then stores each real PreparedStatement (eachInput.getStorageResource()) in the statements field of ShardingSpherePreparedStatement and the matching parameters (eachInput.getExecutionUnit().getSqlUnit().getParameters()) in parameterSets. Finally, replay uses PreparedStatementInvocationReplayer to apply the rewritten parameters to the real PreparedStatement.

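ShardingSpherePreparedStatement implements the java.sql.PreparedStatement interface. Its sql field holds the SQL the user passed in, not the rewritten SQL. The rewrite happens only at execution time and covers both the SQL text and its parameters. connection.prepareStatement is then called with the rewritten SQL to create the real PreparedStatement. A replay step applies the rewritten parameters to it, and ((PreparedStatement) statement).executeUpdate() finally triggers execution.

This gives the basic idea behind ShardingSphere's SQL rewriting. A stand-in class implements java.sql.PreparedStatement, and its creation and parameter settings are first buffered in memory. At execute time, it rewrites the SQL, creates the real PreparedStatement, replays the parameters onto it, and calls its execute method.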
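The whole idea fits in a short sketch. It assumes a hypothetical DeferredStatement that exposes only setObject and executeUpdate rather than the full java.sql.PreparedStatement interface. The caller supplies the rewriter and the real-statement factory, for example a sharding rewrite and a real connection's prepareStatement. Summing the update counts is an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;

// Hypothetical end-to-end sketch: a wrapper that looks like a prepared statement,
// buffers setObject calls, and only at execute time rewrites the SQL, creates the
// real statements, replays the parameters and executes them.
final class DeferredStatement {

    // Minimal stand-in for the driver's real PreparedStatement.
    interface RealStatement {
        void setObject(int index, Object value);

        int executeUpdate();
    }

    // One rewritten SQL plus its (possibly rewritten) parameters.
    static final class RewriteUnit {
        final String sql;
        final List<Object> params;

        RewriteUnit(String sql, List<Object> params) {
            this.sql = sql;
            this.params = params;
        }
    }

    private final String logicSql;                                  // SQL as passed in by the user
    private final Map<Integer, Object> buffered = new TreeMap<>();  // buffered setObject calls, by index
    private final Function<RewriteUnit, List<RewriteUnit>> rewriter;
    private final Function<String, RealStatement> realStatementFactory; // stands in for connection.prepareStatement(sql)

    DeferredStatement(String logicSql, Function<RewriteUnit, List<RewriteUnit>> rewriter,
                      Function<String, RealStatement> realStatementFactory) {
        this.logicSql = logicSql;
        this.rewriter = rewriter;
        this.realStatementFactory = realStatementFactory;
    }

    void setObject(int index, Object value) {
        buffered.put(index, value); // nothing reaches a real statement yet
    }

    int executeUpdate() {
        // Assumes parameter indexes 1..n were all set, so values() is in order.
        List<Object> params = new ArrayList<>(buffered.values());
        int total = 0;
        for (RewriteUnit each : rewriter.apply(new RewriteUnit(logicSql, params))) {
            RealStatement real = realStatementFactory.apply(each.sql); // created with the rewritten SQL
            for (int i = 0; i < each.params.size(); i++) {
                real.setObject(i + 1, each.params.get(i));             // replay, 1-based
            }
            total += real.executeUpdate();
        }
        return total;
    }
}
```

Consider a rewriter that turns UPDATE t_order SET status = ? WHERE user_id = ? into one statement for t_order_0 and one for t_order_1. Calling setObject(1, "PAID") and setObject(2, 7) touches no real statement. executeUpdate then creates both real statements with the rewritten SQL and replays the two parameters onto each.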
