|
@@ -0,0 +1,305 @@
|
|
|
+package com.ruoyi.framework.mybatis.interceptor;
|
|
|
+
|
|
|
+import cn.hutool.extra.spring.SpringUtil;
|
|
|
+import com.alibaba.druid.pool.DruidDataSource;
|
|
|
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
|
|
|
+import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
|
|
|
+import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
|
|
|
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
|
|
|
+import com.baomidou.mybatisplus.extension.plugins.inner.BaseMultiTableInnerInterceptor;
|
|
|
+import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
|
|
|
+import com.ruoyi.common.enums.DataSourceType;
|
|
|
+import com.ruoyi.framework.datasource.DynamicDataSourceContextHolder;
|
|
|
+import com.ruoyi.framework.mybatis.holder.LogicHolder;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import net.sf.jsqlparser.JSQLParserException;
|
|
|
+import net.sf.jsqlparser.expression.Alias;
|
|
|
+import net.sf.jsqlparser.expression.Expression;
|
|
|
+import net.sf.jsqlparser.expression.LongValue;
|
|
|
+import net.sf.jsqlparser.expression.StringValue;
|
|
|
+import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
|
|
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
|
|
+import net.sf.jsqlparser.expression.operators.relational.ItemsList;
|
|
|
+import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
|
|
|
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
|
|
+import net.sf.jsqlparser.schema.Column;
|
|
|
+import net.sf.jsqlparser.schema.Table;
|
|
|
+import net.sf.jsqlparser.statement.delete.Delete;
|
|
|
+import net.sf.jsqlparser.statement.insert.Insert;
|
|
|
+import net.sf.jsqlparser.statement.select.*;
|
|
|
+import net.sf.jsqlparser.statement.update.Update;
|
|
|
+import org.apache.ibatis.executor.Executor;
|
|
|
+import org.apache.ibatis.executor.statement.StatementHandler;
|
|
|
+import org.apache.ibatis.mapping.BoundSql;
|
|
|
+import org.apache.ibatis.mapping.MappedStatement;
|
|
|
+import org.apache.ibatis.mapping.SqlCommandType;
|
|
|
+import org.apache.ibatis.session.ResultHandler;
|
|
|
+import org.apache.ibatis.session.RowBounds;
|
|
|
+import org.springframework.jdbc.core.JdbcTemplate;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+
|
|
|
+import java.sql.Connection;
|
|
|
+import java.util.Arrays;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
+@Slf4j
|
|
|
+@Component
|
|
|
+public class LogicInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 逻辑删除字段名
|
|
|
+ */
|
|
|
+ private static final String delFlagName = "del_flag";
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 未删除值
|
|
|
+ */
|
|
|
+ private static final Expression notDelValue = new LongValue(0);
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 已删除值
|
|
|
+ */
|
|
|
+ private static final Expression delValue = new LongValue(1);
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 包含逻辑删除字段的库表集合
|
|
|
+ * key: 数据源名称
|
|
|
+ * value: 包含逻辑删除字段的表名集合
|
|
|
+ */
|
|
|
+ private static final Map<String, List<String>> notIncludeLogicIdTableNameMap = new HashMap<>();
|
|
|
+
|
|
|
+ public LogicInterceptor() {
|
|
|
+ // 获取主库数据源
|
|
|
+ DruidDataSource dataSource = SpringUtil.getBean("masterDataSource");
|
|
|
+ putTableMap(dataSource, DataSourceType.MASTER);
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 添加不存在租户字段的表map
|
|
|
+ *
|
|
|
+ * @param dataSource 数据源
|
|
|
+ * @param dataSourceType 数据源类型
|
|
|
+ */
|
|
|
+ private void putTableMap(DruidDataSource dataSource, DataSourceType dataSourceType) {
|
|
|
+ // 获取链接url
|
|
|
+ String url = dataSource.getUrl();
|
|
|
+ // 获取数据库名
|
|
|
+ String dbName = url.split("/")[3].split("\\?")[0];
|
|
|
+ // 查询不包含租户字段的表名
|
|
|
+ String sql = "SELECT DISTINCT table_name FROM information_schema.COLUMNS WHERE table_schema = ? AND column_name != ?";
|
|
|
+ // 获取jdbc
|
|
|
+ JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
|
|
|
+ List<String> notIncludeTenantIdTableNameList = jdbcTemplate.queryForList(sql, String.class, dbName, delFlagName);
|
|
|
+ notIncludeLogicIdTableNameMap.put(dataSourceType.name(), notIncludeTenantIdTableNameList);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
|
|
|
+ PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
|
|
|
+ mpBs.sql(parserSingle(mpBs.sql(), null));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
|
|
|
+ PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
|
|
|
+ MappedStatement ms = mpSh.mappedStatement();
|
|
|
+ SqlCommandType sct = ms.getSqlCommandType();
|
|
|
+ // 处理新增编辑
|
|
|
+ if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE) {
|
|
|
+ PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
|
|
|
+ mpBs.sql(parserMulti(mpBs.sql(), null));
|
|
|
+ }
|
|
|
+ // 处理删除
|
|
|
+ if (sct == SqlCommandType.DELETE) {
|
|
|
+ PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
|
|
|
+ String sql = mpBs.sql();
|
|
|
+ try {
|
|
|
+ Delete delete = (Delete) CCJSqlParserUtil.parse(sql);
|
|
|
+ Table table = delete.getTable();
|
|
|
+ if (isSkip(table)) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ sql = "UPDATE " + table.getName() + " SET " + delFlagName + " = " + delValue;
|
|
|
+
|
|
|
+ Expression where = delete.getWhere();
|
|
|
+ if (where != null) {
|
|
|
+ sql += " WHERE " + where;
|
|
|
+ }
|
|
|
+ mpBs.sql(sql);
|
|
|
+ } catch (JSQLParserException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ protected void processSelect(Select select, int index, String sql, Object obj) {
|
|
|
+ final String whereSegment = (String) obj;
|
|
|
+ processSelectBody(select.getSelectBody(), whereSegment);
|
|
|
+ List<WithItem> withItemsList = select.getWithItemsList();
|
|
|
+ if (!CollectionUtils.isEmpty(withItemsList)) {
|
|
|
+ withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ protected void processInsert(Insert insert, int index, String sql, Object obj) {
|
|
|
+ if (isSkip(insert.getTable())) {
|
|
|
+ // 过滤退出执行
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ List<Column> columns = insert.getColumns();
|
|
|
+ if (CollectionUtils.isEmpty(columns)) {
|
|
|
+ // 针对不给列名的insert 不处理
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (columns.stream().map(Column::getColumnName).anyMatch(i -> i.equalsIgnoreCase(delFlagName))) {
|
|
|
+ // 针对已给出逻辑删除列的insert 不处理
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ columns.add(new Column(delFlagName));
|
|
|
+
|
|
|
+ // fixed gitee pulls/141 duplicate update
|
|
|
+ List<Expression> duplicateUpdateColumns = insert.getDuplicateUpdateExpressionList();
|
|
|
+ if (CollectionUtils.isNotEmpty(duplicateUpdateColumns)) {
|
|
|
+ EqualsTo equalsTo = new EqualsTo();
|
|
|
+ equalsTo.setLeftExpression(new StringValue(delFlagName));
|
|
|
+ equalsTo.setRightExpression(notDelValue);
|
|
|
+ duplicateUpdateColumns.add(equalsTo);
|
|
|
+ }
|
|
|
+
|
|
|
+ Select select = insert.getSelect();
|
|
|
+ if (select != null) {
|
|
|
+ this.processInsertSelect(select.getSelectBody(), (String) obj);
|
|
|
+ } else if (insert.getItemsList(ItemsList.class) != null) {
|
|
|
+ ItemsList itemsList = insert.getItemsList(ItemsList.class);
|
|
|
+ if (itemsList instanceof MultiExpressionList) {
|
|
|
+ ((MultiExpressionList) itemsList).getExpressionLists().forEach(el -> el.getExpressions().add(notDelValue));
|
|
|
+ } else {
|
|
|
+ ((ExpressionList) itemsList).getExpressions().add(notDelValue);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * update 语句处理
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ protected void processUpdate(Update update, int index, String sql, Object obj) {
|
|
|
+ final Table table = update.getTable();
|
|
|
+ if (isSkip(table)) {
|
|
|
+ // 过滤退出执行
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ update.setWhere(this.andExpression(table, update.getWhere(), (String) obj));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Expression buildTableExpression(Table table, Expression where, String whereSegment) {
|
|
|
+ if (isSkip(table)) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ return new EqualsTo(new Column(getAliasColumn(table)), notDelValue);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 处理 insert into select
|
|
|
+ * <p>
|
|
|
+ * 进入这里表示需要 insert 的表启用了多租户,则 select 的表都启动了
|
|
|
+ *
|
|
|
+ * @param selectBody SelectBody
|
|
|
+ */
|
|
|
+ protected void processInsertSelect(SelectBody selectBody, final String whereSegment) {
|
|
|
+ PlainSelect plainSelect = (PlainSelect) selectBody;
|
|
|
+ FromItem fromItem = plainSelect.getFromItem();
|
|
|
+ if (fromItem instanceof Table) {
|
|
|
+ processPlainSelect(plainSelect, whereSegment);
|
|
|
+ appendSelectItem(plainSelect.getSelectItems());
|
|
|
+ } else if (fromItem instanceof SubSelect) {
|
|
|
+ SubSelect subSelect = (SubSelect) fromItem;
|
|
|
+ appendSelectItem(plainSelect.getSelectItems());
|
|
|
+ processInsertSelect(subSelect.getSelectBody(), whereSegment);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 追加 SelectItem
|
|
|
+ *
|
|
|
+ * @param selectItems SelectItem
|
|
|
+ */
|
|
|
+ protected void appendSelectItem(List<SelectItem> selectItems) {
|
|
|
+ if (CollectionUtils.isEmpty(selectItems)) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (selectItems.size() == 1) {
|
|
|
+ SelectItem item = selectItems.get(0);
|
|
|
+ if (item instanceof AllColumns || item instanceof AllTableColumns) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ selectItems.add(new SelectExpressionItem(new Column(delFlagName)));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 逻辑删除字段别名设置
|
|
|
+ * <p>
|
|
|
+ * del_flag 或 tableAlias.del_flag
|
|
|
+ * </p>
|
|
|
+ *
|
|
|
+ * @param table 表对象
|
|
|
+ * @return 字段
|
|
|
+ */
|
|
|
+ private String getAliasColumn(Table table) {
|
|
|
+ StringBuilder column = new StringBuilder();
|
|
|
+ if (table.getAlias() != null) {
|
|
|
+ column.append(table.getAlias().getName()).append(StringPool.DOT);
|
|
|
+ }
|
|
|
+ column.append(delFlagName);
|
|
|
+ return column.toString();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 是否跳过执行
|
|
|
+ */
|
|
|
+ private boolean isSkip(Table table) {
|
|
|
+
|
|
|
+ LogicHolder logicHolder = LogicHolder.getLogicHolder();
|
|
|
+ String name = table.getName();
|
|
|
+
|
|
|
+ if (logicHolder != null) {
|
|
|
+
|
|
|
+ String[] tableNames = logicHolder.getTableNames();
|
|
|
+ String[] aliases = logicHolder.getAliases();
|
|
|
+
|
|
|
+ if (tableNames.length == 0 && aliases.length == 0) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (Arrays.asList(tableNames).contains(name)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ Alias alias = table.getAlias();
|
|
|
+ if (alias != null) {
|
|
|
+ if (Arrays.asList(aliases).contains(alias.getName())) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ // 执行sql的数据源名称
|
|
|
+ String dataSourceType = DynamicDataSourceContextHolder.getDataSourceType();
|
|
|
+ // 获取数据源中不包含逻辑删除字段的表名
|
|
|
+ List<String> tableNameList = notIncludeLogicIdTableNameMap.get(dataSourceType);
|
|
|
+ // 如果包涵则跳过拼接逻辑删除
|
|
|
+ return tableNameList.contains(name);
|
|
|
+ }
|
|
|
+
|
|
|
+}
|