package com.shinemo.insurance.common.config;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Map;
import com.shinemo.insurance.common.annotation.TableShard;
import com.shinemo.insurance.common.util.HashUtil;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.ReflectorFactory;
import org.apache.ibatis.reflection.SystemMetaObject;
/**
 * MyBatis plugin that routes statements to hash-sharded physical tables.
 *
 * <p>Intercepts {@link StatementHandler#prepare(Connection, Integer)}. When the mapper method
 * (or, failing that, the mapper interface) behind the statement carries {@link TableShard},
 * every occurrence of {@code tableNamePrefix()} in the SQL is replaced with
 * {@code tableNamePrefix + HashUtil.rotatingHash(shardKey, SHARD_COUNT)}.
 *
 * <p>The shard key is {@code value()} itself when {@code fieldFlag()} is false; otherwise
 * {@code value()} names the mapper parameter (a {@link MapperMethod.ParamMap} / {@link Map} key)
 * or the parameter-object field whose runtime value is the shard key.
 */
@Intercepts({
    @Signature(type = StatementHandler.class, method = "prepare",
            args = {Connection.class, Integer.class})
})
public class TableShardInterceptor implements Interceptor {

    private static final ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /** Modulus passed to {@code HashUtil.rotatingHash}; presumably the number of shard tables. */
    private static final int SHARD_COUNT = 1024;

    /**
     * Rewrites the bound SQL to target the shard table (if the statement is annotated), then
     * lets the prepare call proceed.
     *
     * @param invocation the intercepted {@code StatementHandler#prepare} call
     * @return whatever {@code invocation.proceed()} returns
     * @throws Throwable anything raised while resolving the shard key or by the proceeding call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MetaObject metaObject = getMetaObject(invocation);
        // NOTE(review): the "delegate.*" paths assume the target exposes a "delegate" property
        // (as MyBatis' RoutingStatementHandler does); if another plugin has already wrapped the
        // handler in a proxy, these paths will not resolve — confirm plugin ordering.
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        MappedStatement mappedStatement =
                (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        TableShard tableShard = getTableShard(mappedStatement);
        if (tableShard == null) {
            return invocation.proceed();
        }
        String value = tableShard.value();
        Object shardKey = tableShard.fieldFlag()
                ? resolveShardKey(value, boundSql.getParameterObject())
                : value;
        replaceSql(tableShard, shardKey, metaObject, boundSql);
        return invocation.proceed();
    }

    /**
     * Wraps only {@link StatementHandler} targets; everything else is returned unchanged.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Reads the shard key named {@code name} from the statement's parameter object.
     *
     * @param name            parameter name (map key / {@code @Param} value) or bean field name
     * @param parameterObject the bound parameter object; {@code null} when the mapper method
     *                        takes no arguments
     * @return the non-null shard key
     * @throws RuntimeException           if the key is missing or null, or the sole parameter is
     *                                    an unnamed scalar
     * @throws ReflectiveOperationException if a bean parameter has no such field or it cannot
     *                                    be read
     */
    private Object resolveShardKey(String name, Object parameterObject)
            throws ReflectiveOperationException {
        if (parameterObject == null) {
            throw new RuntimeException(String.format("入参字段%s无匹配", name));
        }
        Object shardKey;
        if (parameterObject instanceof Map) {
            // Covers MapperMethod.ParamMap (multi-arg / @Param methods) as well as plain Maps.
            shardKey = ((Map<?, ?>) parameterObject).get(name);
        } else if (isBaseType(parameterObject)) {
            // A lone scalar argument carries no name to match against; require @Param.
            throw new RuntimeException("单参数非法,请使用@Param注解");
        } else {
            shardKey = readField(parameterObject, name);
        }
        if (shardKey == null) {
            throw new RuntimeException(String.format("入参字段%s无匹配", name));
        }
        return shardKey;
    }

    /**
     * Reads field {@code fieldName} from {@code target}, searching its class and superclasses.
     *
     * @throws NoSuchFieldException if no class in the hierarchy declares the field
     */
    private static Object readField(Object target, String fieldName)
            throws NoSuchFieldException, IllegalAccessException {
        for (Class<?> type = target.getClass(); type != null; type = type.getSuperclass()) {
            Field field;
            try {
                field = type.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                continue; // not declared at this level; try the superclass
            }
            field.setAccessible(true);
            return field.get(target);
        }
        throw new NoSuchFieldException(fieldName);
    }

    /**
     * Returns true for String and the boxed scalar types that cannot carry a named property.
     * (A boxed instance's class is never primitive, so no {@code isPrimitive()} check is needed.)
     */
    private boolean isBaseType(Object object) {
        return object instanceof String
                || object instanceof Integer
                || object instanceof Double
                || object instanceof Float
                || object instanceof Long
                || object instanceof Boolean
                || object instanceof Byte
                || object instanceof Short;
    }

    /**
     * Replaces every literal occurrence of the annotation's table-name prefix in the bound SQL
     * with the hashed shard table name.
     *
     * @param value the non-null shard key; non-String keys are hashed via their string form
     */
    private void replaceSql(TableShard tableShard, Object value, MetaObject metaObject,
            BoundSql boundSql) {
        String tableNamePrefix = tableShard.tableNamePrefix();
        String shardTableName = generateTableName(tableNamePrefix, String.valueOf(value));
        String sql = boundSql.getSql();
        // Literal replace: replaceAll would treat the prefix as a regex and '$' / '\' in the
        // replacement as group references.
        // NOTE(review): this also rewrites the prefix inside any other identifier containing it.
        metaObject.setValue("delegate.boundSql.sql", sql.replace(tableNamePrefix, shardTableName));
    }

    /** Builds {@code tableNamePrefix + rotatingHash(value, SHARD_COUNT)}. */
    private String generateTableName(String tableNamePrefix, String value) {
        int rotatingHash = HashUtil.rotatingHash(value, SHARD_COUNT);
        return tableNamePrefix + rotatingHash;
    }

    /** Wraps the intercepted statement handler in a MyBatis {@link MetaObject}. */
    private MetaObject getMetaObject(Invocation invocation) {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        return MetaObject.forObject(statementHandler,
                SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
    }

    /**
     * Finds the {@link TableShard} governing a mapped statement: first on the mapper method
     * whose name is the last segment of the statement id, then on the mapper type itself.
     *
     * <p>The intercepted method ({@code StatementHandler#prepare}) is never the mapper method,
     * so the mapper method is looked up from the statement id instead.
     *
     * @return the annotation, or {@code null} if none applies or the namespace is not a class
     */
    private TableShard getTableShard(MappedStatement mappedStatement) {
        String id = mappedStatement.getId();
        int lastDot = id.lastIndexOf('.');
        if (lastDot < 0) {
            return null;
        }
        String className = id.substring(0, lastDot);
        String methodName = id.substring(lastDot + 1);
        Class<?> mapperClass;
        try {
            mapperClass = Class.forName(className);
        } catch (ClassNotFoundException ignored) {
            // Namespace with no backing Java type: nothing can carry @TableShard.
            return null;
        }
        for (Method mapperMethod : mapperClass.getMethods()) {
            if (mapperMethod.getName().equals(methodName)) {
                TableShard annotation = mapperMethod.getAnnotation(TableShard.class);
                if (annotation != null) {
                    return annotation;
                }
            }
        }
        return mapperClass.getAnnotation(TableShard.class);
    }
}