通过反射注解批量插入数据到DB的实现方法
批量导入思路
最近遇到一个需要批量导入数据问题。后来考虑运用反射做成一个工具类,思路是首先定义注解接口,在bean类上加注解,运行时通过反射获取传入bean的注解,自动生成需要插入db的sql,根据设置的参数值批量提交。不需要写具体的sql,也没有dao的实现,这样一来批量导入的实现就和具体的数据库表彻底解耦。实际批量执行的sql如下:
insert into company_candidate(company_id,user_id,card_id,facebook_id,type,create_time,weight,score) values (?,?,?,?,?,?,?,?) on duplicate key update type=?,weight=?,score=?
第一步,定义注解接口
注解接口table中定义了数据库名和表名。retentionpolicy.runtime表示该注解保存到运行时,因为我们需要在运行时,去读取注解参数来生成具体的sql。
@documented @retention(retentionpolicy.runtime) @target(elementtype.type) public @interface table { /** * 表名 * @return */ string tablename() default ""; /** * 数据库名称 * @return */ string dbname(); }
注解接口tablefield中定义了数据库表名的各个具体字段名称,以及该字段是否忽略(忽略的话就会以数据库表定义默认值填充,db非null字段的注解不允许出现把ignore注解设置为true)。update注解是在主键在db重复时,需要更新的字段。
@documented @retention(retentionpolicy.runtime) @target(elementtype.field) public @interface tablefield { /** * 对应数据库字段名称 * @return */ string fieldname() default ""; /** * 是否是主键 * @return */ boolean pk() default false; /** * 是否忽略该字段 * @return */ boolean ignore() default false; /** * 当数据存在时,是否更新该字段 * @return */ boolean update() default false; }
第二步,给bean添加注解
给bean添加注解(为了简洁省略了import和set/get方法以及其他属性),@tablefield(fieldname = "company_id")
表示companyid字段对应db表的字段名为"company_id
",其中updatetime属性的注解含有ignore=true
,表示该属性值会被忽略。另外serialversionuid属性由于没有@tablefield注解,在更新db时也会被忽略。
代码如下:
@table(dbname = "company", tablename = "company_candidate") public class companycandidatemodel implements serializable{ private static final long serialversionuid = -1234554321773322135l; @tablefield(fieldname = "company_id") private int companyid; @tablefield(fieldname = "user_id") private int userid; //名片id @tablefield(fieldname = "card_id") private int cardid; //facebookid @tablefield(fieldname = "facebook_id") private long facebookid; @tablefield(fieldname="type", update = true) private int type; @tablefield(fieldname = "create_time") private date createtime; @tablefield(fieldname = "update_time", ignore=true) private date updatetime; // 权重 @tablefield(fieldname="weight", update = true) private int weight; // 分值 @tablefield(fieldname="score", update = true) private double score;
第三步,读取注解的反射工具类
读取第二步bean类的注解的反射工具类。利用反射getannotation(tablefield.class)
读取注解信息,为批量sql的拼接最好准备。
gettablebeanfieldmap()
方法里生成一个linkedhashmap对象,是为了保证生成插入sql的field顺序,之后也能按同样的顺序给参数赋值,避免错位。getsqlparamfields()
方法也类似,是为了给preparedstatement设置参数用。
代码如下:
public class reflectutil { /** * <class,<表定义field名,bean定义field>>的map缓存 */ private static final map<class<?>, map<string field="">> classtablebeanfieldmap = new hashmap<class<?>, map<string field="">>(); // 用来按顺序填充sql参数,其中存储的field和classtablebeanfieldmap保存同样的顺序,但数量多出on duplicate key update部分field private static final map<class<?>, list<field>> sqlparamfieldsmap = new hashmap<class<?>, list<field>>(); private reflectutil(){}; /** * 获取该类上所有@tablefield注解,且没有忽略的字段的map。 * <br />返回一个有序的linkedhashmap类型 * <br />其中key为db表中的字段,value为bean类里的属性field对象 * @param clazz * @return */ public static map<string field=""> gettablebeanfieldmap(class<?> clazz) { // 从缓存获取 map<string field=""> fieldsmap = classtablebeanfieldmap.get(clazz); if (fieldsmap == null) { fieldsmap = new linkedhashmap<string field="">(); for (field field : clazz.getdeclaredfields()) {// 获得所有声明属性数组的一个拷贝 tablefield annotation = field.getannotation(tablefield.class); if (annotation != null && !annotation.ignore() && !"".equals(annotation.fieldname())) { field.setaccessible(true);// 方便后续获取私有域的值 fieldsmap.put(annotation.fieldname(), field); } } // 放入缓存 classtablebeanfieldmap.put(clazz, fieldsmap); } return fieldsmap; } /** * 获取该类上所有@tablefield注解,且没有忽略的字段的map。on duplicate key update后需要更新的字段追加在list最后,为了填充参数值准备 * <br />返回一个有序的arraylist类型 * <br />其中key为db表中的字段,value为bean类里的属性field对象 * @param clazz * @return */ public static list<field> getsqlparamfields(class<?> clazz) { // 从缓存获取 list<field> sqlparamfields = sqlparamfieldsmap.get(clazz); if (sqlparamfields == null) { // 获取所有参数字段 map<string field=""> fieldsmap = gettablebeanfieldmap(clazz); sqlparamfields = new arraylist<field>(fieldsmap.size() * 2); // sql后段on duplicate key update需要更新的字段 list<field> updateparamfields = new arraylist<field>(); iterator<entry<string field="">> iter = fieldsmap.entryset().iterator(); while (iter.hasnext()) { entry<string field=""> entry = (entry<string field="">) iter.next(); field field = entry.getvalue(); // insert语句对应sql参数字段 sqlparamfields.add(field); // on duplicate key update后面语句对应sql参数字段 tablefield annotation = field.getannotation(tablefield.class); if (annotation != null && !annotation.ignore() && annotation.update()) { updateparamfields.add(field); } } sqlparamfields.addall(updateparamfields); // 放入缓存 sqlparamfieldsmap.put(clazz, sqlparamfields); } return sqlparamfields; } /** * 获取表名,对象中使用@table的tablename来标记对应数据库的表名,若未标记则自动将类名转成小写 * * @param clazz * @return */ public static string gettablename(class<?> clazz) { table table = clazz.getannotation(table.class); if (table != null && table.tablename() != null && !"".equals(table.tablename())) { return table.tablename(); } // 当未配置@table的tablename,自动将类名转成小写 return clazz.getsimplename().tolowercase(); } /** * 获取数据库名,对象中使用@table的dbname来标记对应数据库名 * @param clazz * @return */ public static string getdbname(class<?> clazz) { table table = clazz.getannotation(table.class); if (table != null && table.dbname() != null) { // 注解@table的dbname return table.dbname(); } return ""; }
第四步,生成sql语句
根据上一步的方法,生成真正执行的sql语句。
insert into company_candidate(company_id,user_id,card_id,facebook_id,type,create_time,weight,score) values (?,?,?,?,?,?,?,?) on duplicate key update type=?,weight=?,score=?
代码如下:
public class sqlutil { private static final char comma = ','; private static final char brackets_begin = '('; private static final char brackets_end = ')'; private static final char question_mark = '?'; private static final char equal_sign = '='; private static final string insert_begin = "insert into "; private static final string insert_valurs = " values "; private static final string duplicate_update = " on duplicate key update "; // 数据库表名和对应insertupdatesql的缓存 private static final map<string string=""> tableinsertsqlmap = new hashmap<string string="">(); /** * 获取插入的sql语句,对象中使用@tablefield的fieldname来标记对应数据库的列名,若未标记则忽略 * 必须标记@tablefield(fieldname = "company_id")注解 * @param tablename * @param fieldsmap * @return * @throws exception */ public static string getinsertsql(string tablename, map<string field=""> fieldsmap) throws exception { string sql = tableinsertsqlmap.get(tablename); if (sql == null) { stringbuilder sbsql = new stringbuilder(300).append(insert_begin); stringbuilder sbvalue = new stringbuilder(insert_valurs); stringbuilder sbupdate = new stringbuilder(100).append(duplicate_update); sbsql.append(tablename); sbsql.append(brackets_begin); sbvalue.append(brackets_begin); iterator<entry<string field="">> iter = fieldsmap.entryset().iterator(); while (iter.hasnext()) { entry<string field=""> entry = (entry<string field="">) iter.next(); string tablefieldname = entry.getkey(); field field = entry.getvalue(); sbsql.append(tablefieldname); sbsql.append(comma); sbvalue.append(question_mark); sbvalue.append(comma); tablefield tablefield = field.getannotation(tablefield.class); if (tablefield != null && tablefield.update()) { sbupdate.append(tablefieldname); sbupdate.append(equal_sign); sbupdate.append(question_mark); sbupdate.append(comma); } } // 去掉最后的逗号 sbsql.deletecharat(sbsql.length() - 1); sbvalue.deletecharat(sbvalue.length() - 1); sbsql.append(brackets_end); sbvalue.append(brackets_end); sbsql.append(sbvalue); if (!sbupdate.tostring().equals(duplicate_update)) { sbupdate.deletecharat(sbupdate.length() - 1); sbsql.append(sbupdate); } sql = sbsql.tostring(); tableinsertsqlmap.put(tablename, sql); } return sql; }
第五步,批量sql插入实现
从连接池获取connection,sqlutil.getinsertsql()
获取执行的sql语句,根据sqlparamfields来为preparedstatement填充参数值。当循环的值集合到达batchnum时就提交一次。
代码如下:
/** * 批量插入,如果主键一致则更新。结果返回更新记录条数<br /> * @param datalist * 要插入的对象list * @param batchnum * 每次批量插入条数 * @return 更新记录条数 */ public int batchinsertsql(list<? extends object> datalist, int batchnum) throws exception { if (datalist == null || datalist.isempty()) { return 0; } class<?> clazz = datalist.get(0).getclass(); string tablename = reflectutil.gettablename(clazz); string dbname = reflectutil.getdbname(clazz); connection connnection = null; preparedstatement preparedstatement = null; // 获取所有需要更新到db的属性域 map<string field=""> fieldsmap = reflectutil.gettablebeanfieldmap(datalist.get(0).getclass()); // 根据需要插入更新的字段生成sql语句 string sql = sqlutil.getinsertsql(tablename, fieldsmap); log.debug("prepare to start batch operation , sql = " + sql + " , dbname = " + dbname); // 获取和sql语句同样顺序的填充参数fields list<field> sqlparamfields = reflectutil.getsqlparamfields(datalist.get(0).getclass()); // 最终更新结果条数 int result = 0; int parameterindex = 1;// sql填充参数开始位置为1 // 执行错误的对象 list<object> errorsrecords = new arraylist</object><object>(batchnum);//指定数组大小 // 计数器,batchnum提交后内循环累计次数 int innercount = 0; try { connnection = this.getconnection(dbname); // 设置非自动提交 connnection.setautocommit(false); preparedstatement = connnection.preparestatement(sql); // 当前操作的对象 object object = null; int totalrecordcount = datalist.size(); for (int current = 0; current < totalrecordcount; current++) { innercount++; object = datalist.get(current); parameterindex = 1;// 开始参数位置为1 for(field field : sqlparamfields) { // 放入insert语句对应sql参数 preparedstatement.setobject(parameterindex++, field.get(object)); } errorsrecords.add(object); preparedstatement.addbatch(); // 达到批量次数就提交一次 if (innercount >= batchnum || current >= totalrecordcount - 1) { // 执行batch操作 preparedstatement.executebatch(); preparedstatement.clearbatch(); // 提交 connnection.commit(); // 记录提交成功条数 result += innercount; innercount = 0; errorsrecords.clear(); } // 尽早让gc回收 datalist.set(current, null); } return result; } catch (exception e) { // 失败后处理方法 callbackimpl.getinstance().exectuer(sql, errorsrecords, e); batchdbexception be = new batchdbexception("batch run error , dbname = " + dbname + " sql = " + sql, e); be.initcause(e); throw be; } finally { // 关闭 if (preparedstatement != null) { preparedstatement.clearbatch(); preparedstatement.close(); } if (connnection != null) connnection.close(); } }
最后,批量工具类使用例子
在mysql下的开发环境下测试,5万条数据大概13秒。
list<companycandidatemodel> updatedatalist = new arraylist<companycandidatemodel>(50000); // ...为updatedatalist填充数据 int result = batchjdbctemplate.batchinsertsql(updatedatalist, 50);
总结
以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作具有一定的参考学习价值,谢谢大家对的支持。如果你想了解更多相关内容请查看下面相关链接