欢迎您访问程序员文章站！本站旨在为大家分享程序员计算机编程知识！
您现在的位置是: 首页

简仿Mybatis实现原理的小工具

程序员文章站 2022-06-17 16:06:30
...

实现代码如下（WZCDaoHelper.java）：

package com.wzc.daohelper;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.sql.*;
import java.util.*;

/**
 * A minimal MyBatis-style DAO helper.
 *
 * <p>It loads JDBC settings from a properties file and collects {@code <sql id="...">} statements
 * from XML files under a scan directory, keyed as {@code namespace.id}. Statements are executed
 * with {@code #{name}} placeholders bound as JDBC parameters.
 *
 * <p>No synchronization is performed: {@link #scanDaoXML(String)} replaces the statement table
 * with a plain field assignment.
 *
 * @author WANGZIC
 */
public class WZCDaoHelper {

    /** Opening marker of a named parameter placeholder, e.g. {@code #{id}}. */
    private static final String START_PARAM = "#{";

    /** Closing marker of a named parameter placeholder. */
    private static final String END_PARAM = "}";

    /** File-name suffix of the statement definition files picked up by {@link #scanDaoXML(String)}. */
    private static final String XML_TYPE = ".xml";

    /** Classpath location (absolute, from the classpath root) of the no-arg constructor's config. */
    private static final String DEFAULT_CONFIG = "/com/wzc/daohelper/config.properties";

    private String driverName = null;

    private String url = null;

    private String username = null;

    private String password = null;

    private String scanPath = null;

    /** Statement text keyed by {@code namespace.id}; replaced wholesale by each scan. */
    private Map<String, String> sqlCollection = new HashMap<String, String>();

    /**
     * Creates a helper configured from {@code com/wzc/daohelper/config.properties} on the classpath.
     * Configuration errors are printed and leave the helper unconfigured.
     */
    public WZCDaoHelper() {
        // getResourceAsStream works for directories and jars alike, and the leading '/' resolves
        // from the classpath root instead of relative to this class's package.
        try (InputStream is = WZCDaoHelper.class.getResourceAsStream(DEFAULT_CONFIG)) {
            if (is == null) {
                throw new RuntimeException("未找到配置文件");
            }
            loadConfig(is);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Creates a helper configured from the properties file at {@code configFilePath}.
     * Configuration errors are printed and leave the helper unconfigured.
     *
     * @param configFilePath filesystem path of the properties file
     */
    public WZCDaoHelper(String configFilePath) {
        try {
            File configFile = new File(configFilePath);
            if (!configFile.exists()) {
                throw new RuntimeException("未找到配置文件");
            }
            try (InputStream is = new FileInputStream(configFile)) {
                loadConfig(is);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Reads the JDBC settings and scan directory from {@code is}, then scans for SQL XML files. */
    private void loadConfig(InputStream is) throws Exception {
        Properties properties = new Properties();
        properties.load(is);
        this.driverName = properties.getProperty("jdbc.driverName");
        this.url = properties.getProperty("jdbc.url");
        this.username = properties.getProperty("jdbc.username");
        this.password = properties.getProperty("jdbc.password");
        // An absent dao.scanPath means "scan the classpath root itself".
        this.scanPath = resolveFromClasspathRoot(properties.getProperty("dao.scanPath", ""));
        scanDaoXML(scanPath);
    }

    /**
     * Resolves {@code relative} against the classpath root directory.
     *
     * <p>Going through {@code URL.toURI()} decodes escapes such as {@code %20}, which
     * {@code URL.getPath()} would leave in the file name.
     */
    private String resolveFromClasspathRoot(String relative) throws java.net.URISyntaxException {
        java.net.URL root = getClass().getResource("/");
        if (root == null) {
            throw new IllegalStateException("无法定位classpath根目录");
        }
        return new File(new File(root.toURI()), relative).getPath();
    }

    /**
     * Rebuilds the statement table from every {@code .xml} file under {@code path}, descending
     * into subdirectories.
     *
     * <p>Each {@code <sql id="x">} element under a root element with {@code namespace="ns"} is
     * stored as {@code ns.x} with its trimmed text. Files that fail to parse are reported and
     * skipped. A missing directory leaves the table empty.
     *
     * @param path directory to scan
     */
    public void scanDaoXML(String path) {
        Map<String, String> sqlmap = new HashMap<String, String>();
        // One map is shared by the whole recursion, so statements found in subdirectories are kept.
        collectSql(new File(path), sqlmap);
        this.sqlCollection = sqlmap;
    }

    /** Recursively adds the statements of every XML file under {@code dir} to {@code sink}. */
    private static void collectSql(File dir, Map<String, String> sink) {
        File[] files = dir.listFiles();
        if (files == null) {
            return; // not an existing, readable directory
        }
        for (File f : files) {
            if (f.isDirectory()) {
                collectSql(f, sink);
            } else if (f.getName().endsWith(XML_TYPE)) {
                parseSqlFile(f, sink);
            }
        }
    }

    /** Parses one XML file and adds its {@code <sql>} entries to {@code sink}; errors are printed. */
    private static void parseSqlFile(File f, Map<String, String> sink) {
        try {
            DocumentBuilder dbd = DocumentBuilderFactory.newInstance().newDocumentBuilder();
            Document doc = dbd.parse(f);
            Element root = doc.getDocumentElement();
            String namespace = root.getAttribute("namespace");
            NodeList children = root.getElementsByTagName("sql");
            for (int i = 0; i < children.getLength(); i++) {
                Element element = (Element) children.item(i);
                sink.put(namespace + "." + element.getAttribute("id"), element.getTextContent().trim());
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Executes the statement registered under {@code sqlId}.
     *
     * <p>Each {@code #{name}} placeholder is replaced with a JDBC {@code ?} and bound to
     * {@code params.get(name)}. If the statement produces a result set, every row is returned as a
     * column-label to value map. Otherwise a single map {@code {"update": updateCount}} is returned.
     * SQL errors are printed rather than thrown; the list then holds only what was read before
     * the failure.
     *
     * @param sqlId  statement key in the form {@code namespace.id}
     * @param params placeholder values; may be {@code null} when the statement has no placeholders
     * @return result rows, or a single update-count entry
     * @throws IllegalArgumentException if a placeholder is opened but never closed
     * @throws RuntimeException         if the id is unknown, a placeholder value is missing, or no
     *                                  connection can be obtained
     */
    public List<Map<String, Object>> executeSQL(String sqlId, Map<String, Object> params) {
        String template = sqlCollection.get(sqlId);
        if (template == null) {
            throw new RuntimeException("未找到id为:" + sqlId + "的sql语句");
        }
        // Resolve placeholders before opening a connection so a missing value fails fast.
        List<Object> paramValList = new ArrayList<>();
        String sql = bindPlaceholders(template, params, paramValList);

        Connection conn = getConn();
        if (conn == null) {
            throw new RuntimeException("未获取到数据库连接,请检查配置信息");
        }
        List<Map<String, Object>> resultList = new ArrayList<>();
        Statement st = null;
        ResultSet rs = null;
        try {
            boolean hasResultSet;
            if (paramValList.isEmpty()) {
                // No placeholders: a plain Statement sends the text exactly as written.
                st = conn.createStatement();
                hasResultSet = st.execute(sql);
            } else {
                PreparedStatement ps = conn.prepareStatement(sql);
                st = ps; // assign before binding so the finally block always closes it
                for (int i = 0; i < paramValList.size(); i++) {
                    ps.setObject(i + 1, paramValList.get(i));
                }
                hasResultSet = ps.execute();
            }
            // execute() reports whether a result set came back. This also covers queries that do
            // not literally start with SELECT (WITH ..., SHOW ..., leading comments).
            if (hasResultSet) {
                rs = st.getResultSet();
                resultList = parseResultSetToListMap(rs);
            } else {
                Map<String, Object> resmap = new HashMap<String, Object>();
                resmap.put("update", st.getUpdateCount());
                resultList.add(resmap);
            }
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            // Release in reverse order of acquisition.
            closeQuietly(rs);
            closeQuietly(st);
            closeQuietly(conn);
        }
        return resultList;
    }

    /**
     * Rewrites every {@code #{name}} in {@code template} as {@code " ? "}.
     *
     * <p>The matching value from {@code params} is appended to {@code values}, in placeholder
     * order. Names are trimmed, so {@code #{ id }} binds {@code id}.
     *
     * @throws IllegalArgumentException if a placeholder opened with <code>#{</code> is never closed
     * @throws RuntimeException         if {@code params} has no entry for a placeholder name
     */
    private static String bindPlaceholders(String template, Map<String, Object> params, List<Object> values) {
        StringBuilder out = new StringBuilder(template.length());
        int cursor = 0;
        int start;
        while ((start = template.indexOf(START_PARAM, cursor)) >= 0) {
            // Search for the closing marker only after this placeholder's opening marker.
            int end = template.indexOf(END_PARAM, start + START_PARAM.length());
            if (end < 0) {
                throw new IllegalArgumentException("SQL参数占位符未闭合: " + template.substring(start));
            }
            String paramKey = template.substring(start + START_PARAM.length(), end).trim();
            if (params == null || !params.containsKey(paramKey)) {
                throw new RuntimeException("缺少" + paramKey + "字段的值");
            }
            values.add(params.get(paramKey));
            out.append(template, cursor, start).append(" ? ");
            cursor = end + END_PARAM.length();
        }
        return out.append(template, cursor, template.length()).toString();
    }

    /**
     * Converts every remaining row of {@code rs} into a map keyed by column label.
     *
     * <p>Labels honour SQL aliases. Maps keep the select-list order. Values are read by column
     * index, so duplicate column names in a join cannot pick the wrong value.
     *
     * @param rs result set to drain; {@code null} yields an empty list
     * @return the rows read, including those read before any SQL error (which is printed)
     */
    private static List<Map<String, Object>> parseResultSetToListMap(ResultSet rs) {
        List<Map<String, Object>> results = new ArrayList<Map<String, Object>>();
        if (rs == null) {
            return results;
        }
        try {
            ResultSetMetaData rsmd = rs.getMetaData();
            int colLen = rsmd.getColumnCount();
            while (rs.next()) {
                Map<String, Object> row = new LinkedHashMap<String, Object>();
                for (int col = 1; col <= colLen; col++) {
                    row.put(rsmd.getColumnLabel(col), rs.getObject(col));
                }
                results.add(row);
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return results;
    }

    /**
     * Converts every remaining row of {@code rs} into an instance of {@code cls}.
     *
     * <p>Each non-static, non-synthetic declared field whose name matches a result column label
     * (case-insensitively) is set from that column. Fields without a matching column are left at
     * their default. {@code cls} needs an accessible no-arg constructor.
     *
     * @param rs  result set to drain; {@code null} yields an empty list
     * @param cls target type
     * @param <T> target type
     * @return the objects built, including those built before any error (which is printed)
     */
    private static <T> List<T> parseResultSetToBean(ResultSet rs, Class<T> cls) {
        List<T> list = new ArrayList<T>();
        if (rs == null) {
            return list;
        }
        try {
            ResultSetMetaData meta = rs.getMetaData();
            Set<String> columns = new HashSet<String>();
            for (int col = 1; col <= meta.getColumnCount(); col++) {
                columns.add(meta.getColumnLabel(col).toLowerCase(Locale.ROOT));
            }
            // Reflect once, not once per row.
            List<Field> targets = new ArrayList<Field>();
            for (Field f : cls.getDeclaredFields()) {
                if (java.lang.reflect.Modifier.isStatic(f.getModifiers()) || f.isSynthetic()
                        || !columns.contains(f.getName().toLowerCase(Locale.ROOT))) {
                    continue;
                }
                f.setAccessible(true);
                targets.add(f);
            }
            while (rs.next()) {
                T obj = cls.getDeclaredConstructor().newInstance();
                for (Field f : targets) {
                    f.set(obj, rs.getObject(f.getName()));
                }
                list.add(obj);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return list;
    }

    /**
     * Opens a new connection from the configured url and credentials.
     *
     * @return the connection, or {@code null} if it could not be opened (the error is printed)
     */
    private Connection getConn() {
        try {
            // JDBC 4 drivers register themselves; explicit loading is only needed for older
            // drivers and is skipped when no driver class is configured.
            if (driverName != null && !driverName.trim().isEmpty()) {
                Class.forName(driverName.trim());
            }
            return DriverManager.getConnection(url, username, password);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Closes {@code resource} if non-null, printing (not propagating) any failure. */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        System.out.println("欢迎使用WZCDaoHelper");
    }
}

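配置文件（config.properties）。无参构造函数从 classpath 下的 com/wzc/daohelper/config.properties 读取该文件。dao.scanPath 是相对于 classpath 根目录的 SQL XML 扫描目录，留空即扫描根目录本身：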

jdbc.driverName=com.mysql.jdbc.Driver
jdbc.username=root
jdbc.password=root
jdbc.url=jdbc:mysql://127.0.0.1:3306/mytest
dao.scanPath=

 

目录结构 

简仿Mybatis实现原理的小工具

SQL.xml结构。调用 executeSQL 时以“namespace.id”作为 sqlId，例如 user.selectById；语句中的 #{字段名} 会被替换为 JDBC 参数占位符 ?：

<?xml version="1.0" encoding="utf-8"?>
<DaoSQL namespace="user">
    <sql id="insertUser">
        insert into tb_user (id,username,password,birthdate,signature) values(#{id},#{username},#{password},#{birthdate},#{signature})
    </sql>
    <sql id="updatePassword">
        update tb_user set password=#{password} where id=#{id}
    </sql>
    <sql id="deleteById">
        DELETE FROM tb_user WHERE id=#{id}
    </sql>
    <sql id="selectAll">
        select * from tb_user
    </sql>
    <sql id="selectById">
        select * from tb_user where id=#{id}
    </sql>
</DaoSQL>

 

 

相关标签: java mybatis dao