欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

自定义mybatis map返回类型

程序员文章站 2022-07-01 08:01:42
...

1.需求背景

设定订单表order,要根据订单类型统计订单数据,大致sql如下:

select order_type , count(1)  as order_num from order group by order_type;

Mybatis无法将以上sql以指定key:order_type;value:order_num存入至map中。

而Mybatis默认返回的List<Map<String, Object>>,是以每个字段name作为key,字段的值作为value,放入至Map<String,Object>的List数组。
因此自定义一种可以指定key、value字段的Mybatis插件将非常有用。

2.自定义Mybatis拦截器实现Mybaits Map返回类型

有关于Mybatis拦截器的介绍请参阅Mybatis拦截器

由Mybatis源码可知,返回结果集是ResultSetHandler接口的handleResultSets方法实现的,源码如下:

public interface ResultSetHandler {  
      <E> List<E> handleResultSets(Statement stmt) throws SQLException; 
      void handleOutputParameters(CallableStatement cs) throws SQLException;
}

当我们需要返回Map时,只需要对此方法进行拦截,重新组装返回结果数据;当需要拦截时,执行invocation.proceed()。首先我们定义MapParam.java类,用去标识结果集需要拦截,此外在该类中指定返回结果集Map的key和value名称,以及value返回类型。MapParam代码如下:

public class MapParam extends HashMap {   
      // key名称 
      public static final String KEY_FIELD = "keyField";    
      // value名称 
      public static final String VALUE_FIELD = "valueField";    
      // value值类型 
      public static final String VALUE_CLASS = "valueClass";    
      public MapParam(){    }    
      public MapParam(String keyField, String valueField, String valueClass){        
                this.put(KEY_FIELD, keyField);        
                this.put(VALUE_FIELD, valueField);        
                this.put(VALUE_CLASS, valueClass);    }   
      // value值类型枚举类 
      public enum ValueClass {        
                INTEGER("integer"),        
                BIG_DECIMAL("bigDecimal");        
                private String code;       
                public String getCode() {            
                         return code;        
                }        
                ValueClass(String code){            
                        this.code = code;       
                }    
       }
}

通过类MapParam,我们可以定义key和value的字段值,还可以定义value的值类型;

接下来,我们定义MapInterceptor.java,通过对返回结果集方法handleResultSets拦截,返回需要指定的Map数据。少说废话,直接上代码:

@Intercepts(@Signature(method="handleResultSets", type=ResultSetHandler.class, args={Statement.class}))
public class MapInterceptor implements Interceptor {    
         //日志    private static final Logger logger = LoggerFactory.getLogger(MapInterceptor.class);    
         @Override    
        public Object intercept(Invocation invocation) throws Throwable {        
              // 获取代理目标对象        
             Object target = invocation.getTarget();        
             if (target instanceof DefaultResultSetHandler) { 
             DefaultResultSetHandler resultSetHandler = (DefaultResultSetHandler) target;           
             // 利用反射获取参数对象            
             ParameterHandler parameterHandler = reflect(resultSetHandler);
            Object parameterObj = parameterHandler.getParameterObject();
            // 参数对象为MapParam进入处理逻辑            
            if (parameterObj instanceof MapParam) {
                    MapParam mapParam = (MapParam) parameterObj;               
                   // 获取当前statement                
                   Statement stmt = (Statement) invocation.getArgs()[0];               
                  // 根据maoParam返回处理结果               
                   return handleResultSet(stmt.getResultSet(), mapParam);            
             }       
           }       
 return invocation.proceed();   
 }
@Override
public Object plugin(Object target) {    
          return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
}

private Object handleResultSet(ResultSet resultSet, MapParam mapParam){   
          if (null != resultSet){        
          // 获取key field name        
          String keyFieldName = (String)mapParam.get(MapParam.KEY_FIELD);
          // 获取value field name        
          String valueFieldName = (String)mapParam.get(MapParam.VALUE_FIELD);        
         // 值类型        
         String valueClass = (String) mapParam.get(MapParam.VALUE_CLASS);
          List<Object> resultList = new ArrayList<Object>();        
          Map<Object, Object> map = new HashMap<Object, Object>();        
         try {            
              while (resultSet.next()) {                
              Object key = resultSet.getObject(keyFieldName);                
               Object value ;               
               // 根据值类型转换值               
                if (StringUtils.equals(valueClass, MapParam.ValueClass.INTEGER.getCode())) {                    
                     value = resultSet.getInt(valueFieldName);                
               } else if(StringUtils.equals(valueClass, MapParam.ValueClass.BIG_DECIMAL.getCode())) {                    
                     value = resultSet.getBigDecimal(valueFieldName);                
               } else {                    
                    value = resultSet.getObject(valueFieldName);               
             }               
                   map.put(key, value);            
            }        
           } catch (SQLException e) {            
                   logger.error("map interceptor转换异常,{}", e.getMessage());        
           } finally {            
                   // 关闭result set            
                  closeResultSet(resultSet);        
         }       
          resultList.add(map);        
          return resultList;    
}    
return  null;
}

private void closeResultSet(ResultSet resultSet) {    
         try {       
                 if (resultSet != null) {            
                        resultSet.close();        
                 }    
             } catch (SQLException e) {        
                  logger.error("关闭 result set异常,{}", e.getMessage());   
                 }
}

private ParameterHandler reflect(DefaultResultSetHandler resultSetHandler){
      Field field = ReflectionUtils.findField(DefaultResultSetHandler.class, "parameterHandler");    
      field.setAccessible(true);    
      Object value = null;    
      try {       
               value = field.get(resultSetHandler);    
       } catch (Exception e) {        
              logger.error("默认返回结果集反射参数对象异常,{}", e.getMessage()); 
      }    
     return (ParameterHandler)value;
}

@Intercepts(@Signature(method="handleResultSets", type=ResultSetHandler.class, args={Statement.class}))注解代码含义,@Intercepts用于表示该类为拦截器,@Signature用于标识需要拦截的方法名、返回类型及方法参数值。

进入ResultSetHandler的handleResultSets方法,都会进入此拦截器,当需要的请求参数为类MapParam时,执行处理,否则执行invocation.proceed()
代码写的很详细,再次无需再重复讲述。

3.具体实现

注册拦截器:

<bean id="sqlSessionFactory" 
        <property name="plugins">        
              <array>            
                    <bean class="com.test.common.interceptor.MapInterceptor"/>        
              </array>    
          </property>
</bean>

dao层代码:

MapParam params = new MapParam("orderType","orderNum",MapParam.ValueClass.INTEGER.getCode());
Map<String,Integer> find(MapParam params);

xml代码:

<select id="find" resultType="map" parameterType="MapParam">  
    select order_type as orderType , count(1)  as orderNum from order group by order_type;
</select>  

打印结果格式:
“mobileOrder”:“100”
“partsOrder”:“200”
“normalOrder”:“500”