spring boot+mybatis plus进行sql拦截实现权限过滤

spring boot+mybatis plus进行sql拦截实现权限过滤
programApe📃 关联文档
权限数据过滤
定义一个注解用于开启权限过滤功能
由于这次没有参与后台业务部分的开发，并不清楚哪些业务需要该功能，所以没有默认开启过滤，而是将主动权交给业务开发人员
import java.lang.annotation.*;
import static java.lang.annotation.ElementType.*;
/**
 * Marks an element whose SQL queries should be filtered by enterprise id.
 * <p>
 * {@link #unitField()} names the enterprise-id column of the query's main table, and
 * {@link #filterData()} allows filtering to be switched off for a specific element.
 *
 * @author ChenQi
 */
@Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DataScope {

    /**
     * Name of the column in the query's main table that holds the enterprise id.
     *
     * @return column name, {@code "ent_id"} by default
     */
    String unitField() default "ent_id";

    /**
     * Whether data filtering should be applied; {@code false} disables it for the annotated element.
     *
     * @return {@code true} by default
     */
    boolean filterData() default true;
}
定义一个对象，用于存储每次请求中相关接口进行数据过滤时需要使用的数据
import lombok.AllArgsConstructor;
import lombok.Data;
import java.util.Set;
/**
* 类 DataScopeParam
* </p>
*
* @author ChenQi
* @since 2022/10/20 17:37
*/
@Data
@AllArgsConstructor
public class DataScopeParam {
/**
* 企业筛选字段名称
*/
private String unitField;
/**
* 企业数据范围
*/
private Set<Long> entIdList;
/**
* 是否进行拦截
*/
private boolean filterField;
}
使用阿里开源的TransmittableThreadLocal
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>transmittable-thread-local</artifactId>
</dependency>
创建拦截器改写SQL，使其带上权限过滤字段（企业id）的查询条件
import cn.hutool.core.collection.CollUtil;
import com.alibaba.ttl.TransmittableThreadLocal;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.lyc.admin.oauth.service.SysUser;
import com.lyc.admin.oauth.utils.SecurityUtils;
import com.lyc.common.base.annotation.DataScope;
import com.lyc.common.base.constant.CommonConstants;
import com.lyc.common.base.utils.CurrentEntIdSearchContextHolder;
import com.lyc.common.base.vo.EntierVO;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import java.io.StringReader;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Collection;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Data-permission interceptor: rewrites SELECT statements so that only rows whose enterprise-id
 * column lies in the permitted set are returned.
 * <p>
 * The class plays two roles:
 * <ul>
 *   <li>Spring AOP aspect: around every {@link DataScope}-annotated method it resolves the current
 *       user's enterprise ids and stores them in {@link #threadLocal};</li>
 *   <li>MyBatis plugin on {@code StatementHandler.prepare}: before a statement is prepared it appends
 *       the enterprise-id condition to the WHERE clause.</li>
 * </ul>
 * An enterprise list taken from the request header through {@link CurrentEntIdSearchContextHolder}
 * narrows the permitted set. If no annotation is active, that list alone enables filtering on {@code ent_id}.
 *
 * @author ChenQi
 * @since 2022/10/20 14:50
 */
@Aspect
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
@Component
public class UnitDataPermissionInterceptor implements Interceptor {

    /** Condition used when the permitted enterprise set is empty: matches no rows (fails closed). */
    private static final String MATCH_NOTHING = "1 = 0";

    /** Data-scope parameters of the innermost {@link DataScope}-annotated method running on this thread. */
    final ThreadLocal<DataScopeParam> threadLocal = new TransmittableThreadLocal<>();

    /**
     * Removes the data-scope parameters stored for the current thread.
     */
    public void clearThreadLocal() {
        threadLocal.remove();
    }

    /**
     * Pointcut matching every method annotated with {@link DataScope}.
     */
    @Pointcut("@annotation(com.lyc.common.base.annotation.DataScope)")
    public void dataScopePointCut() {
    }

    /**
     * Around advice for {@link DataScope} methods.
     * <p>
     * It installs the method's data-scope parameters for the duration of the call. Afterwards it restores
     * whatever was active before, or clears the slot if nothing was. Restoring, rather than always clearing,
     * keeps an outer scope in force after a nested annotated call returns. With a plain "clear afterwards",
     * an annotated controller that calls an annotated mapper method would run its remaining queries unfiltered.
     *
     * @param point the intercepted invocation
     * @return the annotated method's result
     * @throws Throwable anything thrown by the annotated method
     */
    @Around("dataScopePointCut()")
    public Object aroundDataScope(ProceedingJoinPoint point) throws Throwable {
        DataScopeParam previous = threadLocal.get();
        try {
            doBefore(point);
            return point.proceed();
        } finally {
            if (previous == null) {
                clearThreadLocal();
            } else {
                threadLocal.set(previous);
            }
        }
    }

    /**
     * Resolves the current user's enterprise ids and stores them in {@link #threadLocal}, together with
     * the annotation's settings. Does nothing when the method carries no {@link DataScope} or no current
     * user is available.
     *
     * @param point JoinPoint
     */
    public void doBefore(JoinPoint point) {
        DataScope controllerDataScope = getAnnotationLog(point);
        if (controllerDataScope != null && SecurityUtils.getUser() != null) {
            // Enterprise ids come from the organisation tree attached to the user (built at login, not shown here)
            SysUser sysUser = SecurityUtils.getUser();
            Set<Long> dataScope = sysUser.getTierVos().stream()
                    .map(EntierVO::getUnitIdList)
                    .flatMap(Collection::stream)
                    .collect(Collectors.toSet());
            // No filtering when the annotation sets filterData = false or the user is the super admin
            DataScopeParam dataScopeParam = new DataScopeParam(controllerDataScope.unitField(), dataScope,
                    controllerDataScope.filterData() && !CommonConstants.SUPER_ADMIN.equals(sysUser.getId()));
            threadLocal.set(dataScopeParam);
            log.debug("当前用户可以查看的企业列表数据 = {}", dataScope);
        }
    }

    /**
     * Returns the {@link DataScope} annotation of the intercepted method, or {@code null} if absent.
     */
    private DataScope getAnnotationLog(JoinPoint joinPoint) {
        org.aspectj.lang.Signature signature = joinPoint.getSignature();
        MethodSignature methodSignature = (MethodSignature) signature;
        Method method = methodSignature.getMethod();
        if (method != null) {
            return method.getAnnotation(DataScope.class);
        }
        return null;
    }

    /**
     * Appends the enterprise-id condition to SELECT statements when a data scope is active on the current
     * thread, either through {@link DataScope} or through the header enterprise list. All other statements
     * pass through unchanged.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        DataScopeParam dataScopeParam = threadLocal.get();
        // Enterprise list requested explicitly through the request header
        Set<Long> entIdList = CurrentEntIdSearchContextHolder.getEntIdList();
        if (CollUtil.isNotEmpty(entIdList)) {
            if (dataScopeParam == null) {
                // No annotation active: filter on the requested enterprises only
                dataScopeParam = new DataScopeParam("ent_id", entIdList, true);
            } else {
                // Annotation active: keep only the requested enterprises the user is permitted to see
                Set<Long> permissionEntList = dataScopeParam.getEntIdList();
                dataScopeParam.setFilterField(true);
                dataScopeParam.setEntIdList(entIdList.stream().filter(permissionEntList::contains).collect(Collectors.toSet()));
            }
        }
        // No data scope at all, or the scope is configured not to filter
        if (dataScopeParam == null || !dataScopeParam.isFilterField()) {
            return invocation.proceed();
        }
        // No current user: nothing to filter against
        if (SecurityUtils.getUser() == null) {
            return invocation.proceed();
        }
        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // Only SELECT statements are rewritten
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
            return invocation.proceed();
        }
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        String finalSql = this.handleSql(boundSql.getSql(), dataScopeParam.getEntIdList(), dataScopeParam.getUnitField());
        log.debug("数据权限处理过后的SQL: {}", finalSql);
        // Install the rewritten SQL before the statement is prepared
        metaObject.setValue("delegate.boundSql.sql", finalSql);
        return invocation.proceed();
    }

    /**
     * Rewrites a SELECT statement, adding the enterprise-id condition to every plain SELECT of a set operation.
     *
     * @param originalSql original SQL
     * @param entIdList   permitted enterprise ids
     * @param fieldName   enterprise-id column of the main table
     * @return the rewritten statement
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    private String handleSql(String originalSql, Set<Long> entIdList, String fieldName) throws JSQLParserException {
        CCJSqlParserManager parserManager = new CCJSqlParserManager();
        Select select = (Select) parserManager.parse(new StringReader(originalSql));
        SelectBody selectBody = select.getSelectBody();
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, entIdList, fieldName);
        } else if (selectBody instanceof SetOperationList) {
            SetOperationList setOperationList = (SetOperationList) selectBody;
            List<SelectBody> selectBodyList = setOperationList.getSelects();
            selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, entIdList, fieldName));
        }
        return select.toString();
    }

    /**
     * ANDs the enterprise-id condition onto the WHERE clause of {@code plainSelect}. The condition is
     * qualified with the main table's alias, or its name if it has no alias.
     *
     * @param plainSelect query to modify in place
     * @param entIdList   permitted enterprise ids; an empty set makes the query match no rows
     * @param fieldName   enterprise-id column of the main table
     */
    @SneakyThrows(Exception.class)
    protected void setWhere(PlainSelect plainSelect, Set<Long> entIdList, String fieldName) {
        Table fromItem = (Table) plainSelect.getFromItem();
        // Qualify with the alias when present, otherwise the table name, to avoid ambiguous columns
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? fromItem.getName() : fromItemAlias.getName();
        Expression permission = CCJSqlParserUtil.parseCondExpression(buildPermissionCondition(mainTableName, fieldName, entIdList));
        Expression where = plainSelect.getWhere();
        if (where == null) {
            plainSelect.setWhere(permission);
        } else {
            // Parenthesise the original condition. Otherwise "a = 1 OR b = 2" would become
            // "a = 1 OR b = 2 AND ent_id IN (...)", and since AND binds tighter than OR,
            // rows matching "a = 1" would bypass the permission filter.
            plainSelect.setWhere(new AndExpression(new Parenthesis(where), permission));
        }
    }

    /**
     * Builds the SQL condition restricting {@code mainTableName.fieldName} to {@code entIdList}.
     * Returns {@link #MATCH_NOTHING} for an empty or null set, because {@code IN ()} is invalid SQL.
     */
    private static String buildPermissionCondition(String mainTableName, String fieldName, Set<Long> entIdList) {
        if (CollUtil.isEmpty(entIdList)) {
            return MATCH_NOTHING;
        }
        String column = mainTableName + "." + fieldName;
        if (entIdList.size() == 1) {
            // A single id is rendered as column = id
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(new Column(column));
            equalsTo.setRightExpression(new LongValue(entIdList.iterator().next()));
            return equalsTo.toString();
        }
        return column + " in ( " + CollUtil.join(entIdList, StringPool.COMMA) + " )";
    }

    /**
     * Wraps {@link StatementHandler} targets in a plugin proxy; returns other targets unchanged.
     *
     * @param target target object
     * @return proxy or the original target
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Accepts MyBatis plugin properties; none are used.
     *
     * @param properties MyBatis plugin properties
     */
    @Override
    public void setProperties(Properties properties) {
    }
}
考虑到机构用户可能会指定查询某些企业的数据，对以上权限过滤部分进行了改写以满足新的需求（上面拦截器中读取CurrentEntIdSearchContextHolder的部分即为改写后的逻辑）
添加holder，用于存储接口请求中需要过滤的企业列表
import com.alibaba.ttl.TransmittableThreadLocal;
import lombok.experimental.UtilityClass;
import java.util.Set;
/**
* 类 CurrentEntIdSearchContextHolder
* </p>
*
* @author ChenQi
* @since 2022/10/21 10:13
*/
@UtilityClass
public class CurrentEntIdSearchContextHolder {
private final ThreadLocal<Set<Long>> THREAD_LOCAL_ENT_LIST = new TransmittableThreadLocal<>();
/**
* 设置当前header中的企业列表
*
* @param entIdList 需要查询的企业列表
*/
public void setEntIdList(Set<Long> entIdList) {
THREAD_LOCAL_ENT_LIST.set(entIdList);
}
/**
* 获取header中的企业列表
*
* @return 企业列表
*/
public Set<Long> getEntIdList() {
return THREAD_LOCAL_ENT_LIST.get();
}
public void clear() {
THREAD_LOCAL_ENT_LIST.remove();
}
}
添加过滤器获取并储存待过滤的企业列表
import cn.hutool.core.convert.Convert;
import cn.hutool.core.util.StrUtil;
import com.lyc.common.base.constant.CommonConstants;
import com.lyc.common.base.utils.CurrentEntIdSearchContextHolder;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.GenericFilterBean;
import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.HashSet;
import java.util.Set;
/**
* 类 ContextHolderFilter
* </p>
*
* @author ChenQi
* @since 2022/10/21 10:21
*/
@Slf4j
@Component
@Order(Ordered.HIGHEST_PRECEDENCE)
public class EntIdContextHolderFilter extends GenericFilterBean {
@Override
@SneakyThrows
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) {
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;
Set<Long> entIdList = new HashSet<>();
String entIdListStr = request.getHeader(CommonConstants.ENT_ID_LIST);
if (StrUtil.isNotBlank(entIdListStr)) {
entIdList = Convert.toSet(Long.class, entIdListStr);
log.debug("获取header中的企业列表为:{}", entIdList);
}
CurrentEntIdSearchContextHolder.setEntIdList(entIdList);
filterChain.doFilter(request, response);
CurrentEntIdSearchContextHolder.clear();
}
}
使用方式
添加注解用于过滤数据
同时支持mybatis plus的api和xml中的自定义sql，但@DataScope中通过unitField设定的过滤字段必须存在于sql的主表（from后的第一张表）中
注解添加在controller中,用于使用mybatis plus api的情况
注解添加在controller或者dao层方法上,用于使用xml中自定义sql的情况
指定查询部分企业列表
在header中添加entIdList
评论
匿名评论隐私政策