自己动手实现mybatis动态sql的方法
发现要坚持写博客真的是一件很困难的事情,各种原因都会导致顾不上博客。本来打算写自己动手实现orm,看看时间,还是先实现一个动态sql,下次有时间再补上orm完整的实现吧。
用过mybatis的人,估计对动态sql都不陌生,如果没有用过,就当看看热闹吧。我第一次接触mysql是在大四的时候,当时就觉得动态sql这东西很牛,很灵活,一直想搞明白怎么实现的,尽管当时已经能够写ioc,mvc和简单的orm框架(仿mybaits但是没有动态sql部分),但是仍然找不到mybatis核心的动态sql到底在哪实现的,怎么实现的,可能是那些代码太绕根本没法看懂,直到目前,我都没有勇气去看mybatis的动态sql部分,大概是天生对算法有莫名其妙的敬畏吧。
几年前因为想做一个配置平台,想用解析型语言替代java的实现,可以让配置人员在页面上方便的编写少量代码实现复杂的业务逻辑(包括数据库操作)。当时java已经有js解析引擎,但是大多数人都说效率太低,不知道我发什么疯就想到自己实现一个解析语言。不过实现自己的语言也是我一直的梦想,解析语言相对编译型语言入手简单,于是我就果断动手了,写完才知道,其实自己的实现估计还没有当时的js引擎效率高,那时的我真的是很年轻很简单。今天谈到的动态sql实现其实就是受到那时候解析语言的启发。
废话不多说直接开始聊动态sql,请看下面例子,先声明这里的例子并不是一个正确的sql的写法,只是想写一个尽量复杂的嵌套结构,如果把这种复杂的情况实现了,那么简单一点的就更加不在话下了。
delete from pl_pagewidget <if test="widgetcodes != null"> where pagewidgetcode in <foreach collection="widgetcodes" item="item" index="index" open="(" separator="," close=")"> <if test="index == 0"> #{item} </if> <foreach collection="bs" item="b" index="index1" open="(" separator="," close=")"> #{b} </foreach> </foreach> </if> <if test="a != null"> and a = #{a} </if>
要实现解析出上面例子的sql,首先一个难点类似是test属性里的条件怎么判断真假,不过这个难点在struts2中学到的ognl表达式面前就比较小儿科了。不知道有么有朋友遇到过一个比较奇葩的现象,就是有时候明明在mybatis动态sql中写如下表达式,但是当n=0的时候居然是满足条件的也就是test里的值是false,0居然不能满足这个表达式的条件,这里就是ognl库的原因了。没办法它就是这么玩的,当成特殊情况记住就可以了
test="n != null and n !=''"
ognl表达式使用很方便如下
import java.util.hashmap; import java.util.map; import ognl.ognl; public class ognltest { //输出结果:false public static void main(string[] args) throws exception { string con1 = "n != null and n != ''"; map<string,object> root = new hashmap<>(); root.put("n", 0); system.out.println(ognl.getvalue(con1,root)); } }
要实现解析上面例子的sql,第二个难点就是虽然这个sql披上一层xml的皮就是一个标准的sql,如下
<sql> delete from pl_pagewidget <if test="widgetcodes != null"> where pagewidgetcode in <foreach collection="widgetcodes" item="item" index="index" open="(" separator="," close=")"> <if test="index == 0"> #{item} </if> <foreach collection="bs" item="b" index="index1" open="(" separator="," close=")"> #{b} </foreach> </foreach> </if> <if test="a != null"> and a = #{a} </if> </sql>
但是要解析上面的xml和我们平时不一样,这个xml是标签和文本混合的,正常我们开发中应该很少会用到解析这种xml。不过我们常用的解析xml的工具dom4j其实可以很好的解析这种sql,只不过很少可能用到。element类的content()方法就可以返回一个node的集合,再通过遍历这个集合,判断每个node的类型就可以了。解决了这两个重点,只需要加上一点技巧就可以解析这种动态sql了。
我用到的技巧是根据java语法格式得到的启发。比如java中有局部变量和全局变量,不考虑引用传递这种情况,如果全局变量int i = 1;方法里面传入这个全局变量,然后在方法里面修改,在方法里面看到的是改变后的值,但是在方法外面看到的仍然是1。这个现象其实学过java应该都知道。还有就是当方法调用的时候,方法里面可以看到全局变量,也可以看到局部变量,方法调用结束后局部变量会被清空释放(看垃圾搜集器高兴)。介绍了这些直接上代码了
import java.io.stringreader; import java.text.simpledateformat; import java.util.arrays; import java.util.date; import java.util.hashmap; import java.util.list; import java.util.map; import java.util.regex.matcher; import java.util.regex.pattern; import org.apache.commons.collections.maputils; import org.apache.commons.lang.stringutils; import org.dom4j.document; import org.dom4j.element; import org.dom4j.node; import org.dom4j.text; import org.dom4j.io.saxreader; import com.rd.sql.attrs; import com.rd.sql.basenode; import com.rd.sql.nodefactory; public class sqlparser { private map<string,object> currparams = new hashmap<string,object>(); /** delete from pl_pagewidget <if test="widgetcodes != null"> where pagewidgetcode in <foreach collection="widgetcodes" item="item" index="index" open="(" separator="," close=")"> <if test="index == 0"> #{item} </if> <foreach collection="bs" item="b" index="index1" open="(" separator="," close=")"> #{b} </foreach> </foreach> </if> <if test="a != null"> and a = #{a} </if> */ public static void main(string[] args) throws exception { map<string, object> map = new hashmap<string, object>(); map.put("widgetcodes", arrays.aslist("1", "2")); map.put("bs", arrays.aslist("3", "4")); map.put("a", 1); sqlparser parser = new sqlparser(); system.out .println(parser.parser("delete from pl_pagewidget\n" + "\t<if test=\"widgetcodes != null\">\n" + "\t\twhere pagewidgetcode in\n" + "\t\t<foreach collection=\"widgetcodes\" item=\"item\" index=\"index\" open=\"(\" separator=\",\" close=\")\">\n" + "\t\t <if test=\"index == 0\">\n" + "\t\t #{item}\n" + "\t\t </if>\n" + "\t\t <foreach collection=\"bs\" item=\"b\" index=\"index1\" open=\"(\" separator=\",\" close=\")\">\n" + "\t\t\t#{b}\n" + "\t\t </foreach>\n" + "\t\t</foreach>\n" + "\t</if>\n" + "\t<if test=\"a != null\">\n" + "\t\tand a = #{a}\n" + "\t</if>\n", map)); system.out.println(parser.getparams()); } public string parser(string xml, map<string, object> params) throws exception { // xml = "<?xml version=\"1.0\" encoding=\"utf-8\"?>"+xml; //给输入的动态sql套一层xml标签 xml = "<sql>"+xml+"</sql>"; saxreader reader = new saxreader(false); document document = reader.read(new stringreader(xml)); element element = document.getrootelement(); map<string, object> currparams = new hashmap<string, object>(); stringbuilder sb = new stringbuilder(); //开始解析 parserelement(element, currparams, params, sb); return sb.tostring(); } /** * 使用递归解析动态sql * @param ele1 待解析的xml标签 * @param currparams * @param globalparams * @param sb * @throws exception */ private void parserelement(element ele1, map<string, object> currparams, map<string, object> globalparams, stringbuilder sb) throws exception { // 解析一个节点,比如解析到了一个if节点,假如test判断为true这里就返回true tempval val = parseroneelement(currparams, globalparams, ele1, sb); //得到解析的这个节点的抽象节点对象 basenode node = val.getnode(); /** * 实际上这句之上的语句只是解析了xml的标签,并没有解析标签里的内容,这里 * 表示要解析内容之前,如果有前置操作做一点前置操作 */ node.pre(currparams, globalparams, ele1, sb); //判断是否还需要解析节点里的内容,例如if节点test结果为true boolean flag = val.iscontinue(); // 得到该节点下的所有子节点的集合,包含普通文本 list<node> nodes = ele1.content(); if (flag && !nodes.isempty()) { /** * 这里表示要进一步解析节点里的内容了,可以把节点类比成一个方法的外壳 * 里面的内容类比成方法里的具体语句,开始解析节点的内容之前 * 先创建本节点下的局部参数的容器,最方便当然是map */ map<string, object> params = new hashmap<string, object>(); /** * 把外面传进来的局部参数,直接放入容器,由于本例中参数都是常用数据类型 * 不会存在引用类型所以,这里相当于是一个copy,为了不影响外面传入的对象 * 可以类比方法调用传入参数的情况 */ params.putall(currparams); //循环所有子节点 for (int i = 0; i < nodes.size();) { node n = nodes.get(i); //如果节点是普通文本 if (n instanceof text) { string text = ((text) n).getstringvalue(); if (stringutils.isnotempty(text.trim())) { //处理一下文本,如处理#{xx},直接替换${yy}为真实传入的值 sb.append(handtext(text, params,globalparams)); } i++; } else if (n instanceof element) { element e1 = (element) n; // 递归解析xml子元素 parserelement(e1, params, globalparams, sb); // 如果循环标志不为true则解析下一个标签 // 这里表示需要重复解析这个循环标签,则i不变,反之继续处理下一个元素 boolean while_flag = maputils.getboolean(params, attrs.while_flag, false); if (!while_flag || !nodefactory.iswhile(n.getname()) || e1.attributevalue(attrs.index) == null || !e1.attributevalue(attrs.index).equals( params.get(attrs.while_index))) { i++; } } } //节点处理之后做一些啥事 node.after(currparams, globalparams, ele1, sb); // 回收当前作用域参数 params.clear(); params = null; } } /** * 处理文本替换掉#{item}这种参数 * @param str * @param params * @return * @throws exception */ private string handtext(string str, map<string, object> params,map<string, object> globalparams) throws exception { //获取foreach这种标签中用于记录循环的变量 string indexstr = maputils.getstring(params, attrs.while_index); integer index = null; if(stringutils.isnotempty(indexstr)) { index = maputils.getinteger(params, indexstr); } //匹配#{a}这种参数 string reg1 = "(#\\{)(\\w+)(\\})"; //匹配${a}这种参数 string reg2 = "(\\$\\{)(\\w+)(\\})"; pattern p1 = pattern.compile(reg1); matcher m1 = p1.matcher(str); pattern p2 = pattern.compile(reg2); matcher m2 = p2.matcher(str); string whilelist = maputils.getstring(params, attrs.while_list); map<string,object> allparams = getallparams(params, globalparams); while(m1.find()) { string tmpkey = m1.group(2); string key = whilelist == null?tmpkey:(whilelist+"_"+tmpkey); key = index == null?key:(key+index); string rekey = "#{"+key+"}"; //如果在foreach类似的循环里,可能需要将参数#{xx}替换成#{xx_0},#{xx_1} str = str.replace(m1.group(0), rekey); currparams.put(key, allparams.get(tmpkey)); } while(m2.find()) { string tmpkey = m2.group(2); object value = allparams.get(tmpkey); if(value != null) { str = str.replace(m2.group(0), getvalue(value)); } } return str; } private string getvalue(object value) { string result = ""; if(value instanceof date) { simpledateformat sdf = new simpledateformat("yyyy-mm-dd hh:mm:ss"); result = sdf.format((date)value); } else { result = string.valueof(value); } return result; } private map<string, object> getallparams(map<string, object> currparams, map<string, object> globalparams) { map<string,object> allparams = new hashmap<string,object>(); allparams.putall(globalparams); allparams.putall(currparams); return allparams; } // 解析一个xml元素 private tempval parseroneelement(map<string, object> currparams, map<string, object> globalparams, element ele, stringbuilder sb) throws exception { //获取xml标签名 string elename = ele.getname(); //解析一个节点后是否继续,如遇到if这种节点,就需要判断test里是否为空 boolean iscontinue = false; //声明一个抽象节点 basenode node = null; if (stringutils.isnotempty(elename)) { //使用节点工厂根据节点名得到一个节点对象比如是if节点还是foreach节点 node = nodefactory.create(elename); //解析一下这个节点,返回是否还需要解析节点里的内容 iscontinue = node.parse(currparams, globalparams, ele, sb); } return new tempval(iscontinue, ele, node); } public map<string, object> getparams() { return currparams; } /** * 封装一个xml元素被解析后的结果 * @author rongdi */ final static class tempval { private boolean iscontinue; private element ele; private basenode node; public tempval(boolean iscontinue, element ele, basenode node) { this.iscontinue = iscontinue; this.ele = ele; this.node = node; } public boolean iscontinue() { return iscontinue; } public void setcontinue(boolean iscontinue) { this.iscontinue = iscontinue; } public element getele() { return ele; } public void setele(element ele) { this.ele = ele; } public basenode getnode() { return node; } public void setnode(basenode node) { this.node = node; } } }
import org.dom4j.element; import java.util.hashmap; import java.util.map; /** * 抽象节点 * @author rongdi */ public abstract class basenode { public abstract boolean parse(map<string, object> currparams, map<string, object> globalparams, element ele,stringbuilder sb) throws exception; public void pre(map<string, object> currparams,map<string, object> globalparams,element ele,stringbuilder sb) throws exception { } public void after(map<string, object> currparams,map<string, object> globalparams,element ele,stringbuilder sb) throws exception { } protected map<string, object> getallparams(map<string, object> currparams, map<string, object> globalparams) { map<string,object> allparams = new hashmap<string,object>(); allparams.putall(globalparams); allparams.putall(currparams); return allparams; } }
import java.util.map; import ognl.ognl; import org.apache.commons.lang.stringutils; import org.dom4j.element; /** * if节点 * @author rongdi */ public class ifnode extends basenode{ @override public boolean parse(map<string, object> currparams, map<string, object> globalparams, element ele,stringbuilder sb) throws exception { //得到if节点的test属性 string teststr = ele.attributevalue("test"); boolean test = false; try { if(stringutils.isnotempty(teststr)) { //合并全局变量和局部变量 map<string, object> allparams = getallparams(currparams,globalparams); //使用ognl判断true或者false test = (boolean) ognl.getvalue(teststr,allparams); } } catch (exception e) { e.printstacktrace(); throw new exception("判断操作参数"+teststr+"不合法"); } if(ele.content() != null && ele.content().size()==0) { test = true; } return test; } }
import java.util.arraylist; import java.util.hashmap; import java.util.list; import java.util.map; import java.util.set; import ognl.ognl; import org.apache.commons.collections.maputils; import org.apache.commons.lang.stringutils; import org.dom4j.element; /** foreach节点属性如下 collection 需要遍历的集合 item 遍历集合后每个元素存放的变量 index 遍历集合的索引数如0,1,2... separator 遍历后以指定分隔符拼接 open 遍历后拼接开始的符号如 ( close 遍历后拼接结束的符号如 ) */ public class foreachnode extends basenode { @override public boolean parse(map<string, object> currparams, map<string, object> globalparams, element ele, stringbuilder sb) throws exception { string conditionstr = null; string collectionstr = ele.attributevalue("collection"); string itemstr = ele.attributevalue("item"); string index = ele.attributevalue("index"); string separatorstr = ele.attributevalue("separator"); string openstr = ele.attributevalue("open"); string closestr = ele.attributevalue("close"); if(stringutils.isempty(index)) { index = "index"; } if(stringutils.isempty(separatorstr)) { separatorstr = ","; } if(stringutils.isnotempty(openstr)) { currparams.put(attrs.while_open,openstr); } if(stringutils.isnotempty(closestr)) { currparams.put(attrs.while_close,closestr); } if(stringutils.isnotempty(collectionstr)) { currparams.put(attrs.while_list,collectionstr); } currparams.put(attrs.while_separator,separatorstr); if(index != null) { /** * 如果局部变量中存在当前循环变量的值,就表示已经不是第一次进入循环标签了,移除掉开始标记 * 并将局部变量值加1 */ if(currparams.get(index) != null) { currparams.remove(attrs.while_start); currparams.put(index+"_", (integer)currparams.get(index+"_") + 1); } else { //第一次进入循环标签内 currparams.put(attrs.while_start,true); currparams.put(index+"_", 0); } currparams.put(index, (integer)currparams.get(index+"_")); } boolean condition = true; map<string, object> allparams = getallparams(currparams,globalparams); object collection = null; if(stringutils.isnotempty(collectionstr)) { //得到待循环的集合 collection = ognl.getvalue(collectionstr,allparams); //如果集合属性不为空,但是条件为null则默认加上一个边界条件 if(stringutils.isempty(conditionstr)) { //这里只是用集合演示一下,也可以再加上数组,只不过改成.length而已 if(collection instanceof list) { conditionstr = index+"_<"+collectionstr+".size()"; } else if(collection instanceof map){ map map = (map)collection; set set = map.entryset(); list list = new arraylist(set); allparams.put("_list_", list); conditionstr = index+"_<_list_"+".size()"; } } } currparams.remove(attrs.while_end); if(stringutils.isnotempty(conditionstr)) { //计算条件的值 condition = (boolean)ognl.getvalue(conditionstr,allparams); map<string,object> tempmap = new hashmap<>(); tempmap.putall(allparams); tempmap.put(index+"_",(integer)currparams.get(index+"_") + 1); currparams.put(attrs.while_end,!(boolean)ognl.getvalue(conditionstr,tempmap)); } boolean flag = true; currparams.put(attrs.while_index, index); currparams.put(attrs.while_flag, true); if(condition) { try { if(stringutils.isnotempty(itemstr) && stringutils.isnotempty(collectionstr)) { object value = null; int idx = integer.parseint(currparams.get(index+"_").tostring()); if(collection instanceof list) { value = ((list)collection).get(idx); currparams.put(itemstr, value); } else if(collection instanceof map){ map map = (map)collection; set<map.entry<string,object>> set = map.entryset(); list<map.entry<string,object>> list = new arraylist(set); currparams.put(itemstr, list.get(idx).getvalue()); currparams.put(index, list.get(idx).getkey()); } } } catch (exception e) { throw new exception("从集合或者映射取值"+currparams.get(index)+"错误"+e.getmessage()); } } else { flag = false; destroyvars(currparams, index, itemstr); } return flag; } /** * 如果是第一次进入循环标签,则拼上open的内容 */ @override public void pre(map<string, object> currparams, map<string, object> globalparams, element ele, stringbuilder sb) throws exception { super.pre(currparams, globalparams, ele, sb); boolean start = maputils.getboolean(currparams,attrs.while_start,false); if(start) { string open = maputils.getstring(currparams,attrs.while_open); sb.append(open); } } /** * 如果是最后进入循环标签,则最后拼上close的内容 */ @override public void after(map<string, object> currparams, map<string, object> globalparams, element ele, stringbuilder sb) throws exception { super.after(currparams, globalparams, ele, sb); boolean end = maputils.getboolean(currparams,attrs.while_end,false); string separator = maputils.getstring(currparams,attrs.while_separator); if(!end && stringutils.isnotempty(separator)) { sb.append(separator); } if(end) { string close = maputils.getstring(currparams,attrs.while_close); if(sb.tostring().endswith(separator)) { sb.deletecharat(sb.length() - 1); } sb.append(close); } } //释放临时变量 private void destroyvars(map<string, object> currparams, string index,string varstr) { currparams.remove(attrs.while_index); currparams.remove(attrs.while_flag); currparams.remove(attrs.while_separator); currparams.remove(attrs.while_start); currparams.remove(attrs.while_end); currparams.remove(attrs.while_list); } } import org.dom4j.element; import java.util.map; public class sqlnode extends basenode{ @override public boolean parse(map<string, object> currparams, map<string, object> globalparams, element ele,stringbuilder sb) throws exception { return true; } } import java.util.arrays; import java.util.list; import java.util.map; import java.util.concurrent.concurrenthashmap; /** * 节点工厂 */ public class nodefactory { private static map<string,basenode> nodemap = new concurrenthashmap<string,basenode>(); private final static list<string> whilelist = arrays.aslist("foreach"); static { nodemap.put("if", new ifnode()); nodemap.put("sql", new sqlnode()); nodemap.put("foreach", new foreachnode()); } public static boolean iswhile(string elementname) { return whilelist.contains(elementname); } public static void addnode(string nodename,basenode node) { nodemap.put(nodename, node); } public static basenode create(string nodename) { return nodemap.get(nodename); } } /** * 各种标记 * @author rongdi */ public class attrs { public final static string transactional = "transactional"; public final static string while_start = "while-start"; public final static string while_end = "while-end"; public final static string while_open = "while-open"; public final static string while_close = "while-close"; public final static string while_separator = "while-separator"; public final static string while_index = "while-index"; public final static string while_flag = "while-flag"; public final static string while_list = "while-list"; public final static string when_flag = "when-flag"; public static final string process_var = "process-var"; public final static string result_flag = "result-flag"; public final static string return_flag = "return-flag"; public final static string console_var= "console-var"; public final static string do = "do"; public final static string index = "index"; public final static string condition = "condition"; public final static string name= "name"; public final static string value= "value"; public static final string type = "type"; public static final string format = "format"; public static final string if = "if"; public static final string else = "else"; public final static string file= "file"; public static final string date = "date"; public static final string now = "now"; public static final string decimal = "decimal"; public static final string id = "id"; public static final string params = "params"; public static final string target = "target"; public static final string single = "single"; public static final string paging = "paging"; public static final string desc = "desc"; public static final string break = "break"; public static final string continue = "continue"; public static final string collection = "collection"; public static final string var = "var"; public static final string executor = "executor-1"; public static final string rollback_flag = "rollback-flag"; public static final string service = "service"; public static final string ref = "ref"; public static final string bizs = "bizs"; public static final string titles = "titles"; public static final string columns = "columns"; public static final string curruser = "curruser"; public static final string currperm = "currperm"; public static final string task_executor = "taskexecutor"; public static final string delimiter = "delimiter"; public static final string opername = "opername"; } currparams.remove(varstr); currparams.remove(index); currparams.remove(index+"_"); } }
附上pom文件
<project xmlns="http://maven.apache.org/pom/4.0.0" xmlns:xsi="http://www.w3.org/2001/xmlschema-instance" xsi:schemalocation="http://maven.apache.org/pom/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> <modelversion>4.0.0</modelversion> <groupid>com.rd</groupid> <artifactid>parser</artifactid> <packaging>jar</packaging> <version>1.0-snapshot</version> <name>myparser</name> <url>http://maven.apache.org</url> <dependencies> <dependency> <groupid>dom4j</groupid> <artifactid>dom4j</artifactid> <version>1.6.1</version> </dependency> <dependency> <groupid>opensymphony</groupid> <artifactid>ognl</artifactid> <version>2.6.11</version> </dependency> <dependency> <groupid>commons-collections</groupid> <artifactid>commons-collections</artifactid> <version>3.2.1</version> </dependency> <dependency> <groupid>commons-lang</groupid> <artifactid>commons-lang</artifactid> <version>2.6</version> </dependency> <dependency> <groupid>junit</groupid> <artifactid>junit</artifactid> <version>3.8.1</version> <scope>test</scope> </dependency> </dependencies> <build> <resources> <resource> <directory>src/main/java</directory> <includes> <include>**/*.xml</include> </includes> </resource> <resource> <directory>src/main/resources</directory> <includes> <include>**/*</include> </includes> </resource> </resources> <testresources> <testresource> <directory>${project.basedir}/src/test/java</directory> </testresource> <testresource> <directory>${project.basedir}/src/test/resources</directory> </testresource> </testresources> <plugins> <plugin> <groupid>org.apache.maven.plugins</groupid> <artifactid>maven-compiler-plugin</artifactid> <version>3.1</version> <configuration> <source>1.8</source> <target>1.8</target> <encoding>utf-8</encoding> </configuration> </plugin> </plugins> </build> </project>
以上这篇自己动手实现mybatis动态sql的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
上一篇: 设计模式(一)单例模式详解
下一篇: java外卖订餐系统小项目