欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

通过反射注解批量插入数据到DB的实现方法

程序员文章站 2024-02-26 10:05:34
批量导入思路 最近遇到一个需要批量导入数据问题。后来考虑运用反射做成一个工具类,思路是首先定义注解接口,在bean类上加注解,运行时通过反射获取传入bean的注解,自动生...

批量导入思路

最近遇到一个需要批量导入数据问题。后来考虑运用反射做成一个工具类,思路是首先定义注解接口,在bean类上加注解,运行时通过反射获取传入bean的注解,自动生成需要插入db的sql,根据设置的参数值批量提交。不需要写具体的sql,也没有dao的实现,这样一来批量导入的实现就和具体的数据库表彻底解耦。实际批量执行的sql如下:

insert into company_candidate(company_id,user_id,card_id,facebook_id,type,create_time,weight,score) values (?,?,?,?,?,?,?,?) on duplicate key update type=?,weight=?,score=?

第一步,定义注解接口

注解接口table中定义了数据库名和表名。retentionpolicy.runtime表示该注解保存到运行时,因为我们需要在运行时,去读取注解参数来生成具体的sql。

@documented
@retention(retentionpolicy.runtime)
@target(elementtype.type)
public @interface table {
  /**
   * 表名
   * @return
   */
  string tablename() default "";
  /**
   * 数据库名称
   * @return
   */
  string dbname();
}

注解接口tablefield中定义了数据库表名的各个具体字段名称,以及该字段是否忽略(忽略的话就会以数据库表定义默认值填充,db非null字段的注解不允许出现把ignore注解设置为true)。update注解是在主键在db重复时,需要更新的字段。

@documented
@retention(retentionpolicy.runtime)
@target(elementtype.field)
public @interface tablefield {
  /**
   * 对应数据库字段名称
   * @return
   */
  string fieldname() default "";
  /**
   * 是否是主键
   * @return
   */
  boolean pk() default false;
  /**
   * 是否忽略该字段
   * @return
   */
  boolean ignore() default false;
  /**
   * 当数据存在时,是否更新该字段
   * @return
   */
  boolean update() default false;
}

第二步,给bean添加注解

给bean添加注解(为了简洁省略了import和set/get方法以及其他属性),@tablefield(fieldname = "company_id")表示companyid字段对应db表的字段名为"company_id",其中updatetime属性的注解含有ignore=true,表示该属性值会被忽略。另外serialversionuid属性由于没有@tablefield注解,在更新db时也会被忽略。

代码如下:

@table(dbname = "company", tablename = "company_candidate")
public class companycandidatemodel implements serializable{
 private static final long serialversionuid = -1234554321773322135l;
 @tablefield(fieldname = "company_id")
 private int companyid;
 @tablefield(fieldname = "user_id")
 private int userid;
 //名片id
 @tablefield(fieldname = "card_id")
 private int cardid;
 //facebookid
 @tablefield(fieldname = "facebook_id")
 private long facebookid;
  @tablefield(fieldname="type", update = true)
 private int type;
 @tablefield(fieldname = "create_time")
 private date createtime;
 @tablefield(fieldname = "update_time", ignore=true)
 private date updatetime;
 // 权重
  @tablefield(fieldname="weight", update = true)
 private int weight;
 // 分值
  @tablefield(fieldname="score", update = true)
 private double score;

第三步,读取注解的反射工具类

读取第二步bean类的注解的反射工具类。利用反射getannotation(tablefield.class)读取注解信息,为批量sql的拼接最好准备。

gettablebeanfieldmap()方法里生成一个linkedhashmap对象,是为了保证生成插入sql的field顺序,之后也能按同样的顺序给参数赋值,避免错位。getsqlparamfields()方法也类似,是为了给preparedstatement设置参数用。

代码如下:

public class reflectutil {
  /**
   * <class,<表定义field名,bean定义field>>的map缓存
   */
  private static final map<class<?>, map<string field="">> classtablebeanfieldmap = new hashmap<class<?>, map<string field="">>();
  // 用来按顺序填充sql参数,其中存储的field和classtablebeanfieldmap保存同样的顺序,但数量多出on duplicate key update部分field
  private static final map<class<?>, list<field>> sqlparamfieldsmap = new hashmap<class<?>, list<field>>(); 
  private reflectutil(){};
  /**
   * 获取该类上所有@tablefield注解,且没有忽略的字段的map。
   * <br />返回一个有序的linkedhashmap类型
   * <br />其中key为db表中的字段,value为bean类里的属性field对象
   * @param clazz
   * @return
   */
  public static map<string field=""> gettablebeanfieldmap(class<?> clazz) {
   // 从缓存获取
   map<string field=""> fieldsmap = classtablebeanfieldmap.get(clazz);
   if (fieldsmap == null) {
   fieldsmap = new linkedhashmap<string field="">();
      for (field field : clazz.getdeclaredfields()) {// 获得所有声明属性数组的一个拷贝
       tablefield annotation = field.getannotation(tablefield.class);
        if (annotation != null && !annotation.ignore() && !"".equals(annotation.fieldname())) {
          field.setaccessible(true);// 方便后续获取私有域的值
         fieldsmap.put(annotation.fieldname(), field);
        }
  }
      // 放入缓存
      classtablebeanfieldmap.put(clazz, fieldsmap);
   }
   return fieldsmap;
  }
  /**
   * 获取该类上所有@tablefield注解,且没有忽略的字段的map。on duplicate key update后需要更新的字段追加在list最后,为了填充参数值准备
   * <br />返回一个有序的arraylist类型
   * <br />其中key为db表中的字段,value为bean类里的属性field对象
   * @param clazz
   * @return
   */
  public static list<field> getsqlparamfields(class<?> clazz) {
   // 从缓存获取
   list<field> sqlparamfields = sqlparamfieldsmap.get(clazz);
   if (sqlparamfields == null) {
   // 获取所有参数字段
     map<string field=""> fieldsmap = gettablebeanfieldmap(clazz);
   sqlparamfields = new arraylist<field>(fieldsmap.size() * 2);
     // sql后段on duplicate key update需要更新的字段
     list<field> updateparamfields = new arraylist<field>();
   iterator<entry<string field="">> iter = fieldsmap.entryset().iterator();
   while (iter.hasnext()) {
    entry<string field=""> entry = (entry<string field="">) iter.next();
    field field = entry.getvalue();
    // insert语句对应sql参数字段
    sqlparamfields.add(field);
        // on duplicate key update后面语句对应sql参数字段
        tablefield annotation = field.getannotation(tablefield.class);
    if (annotation != null && !annotation.ignore() && annotation.update()) {
    updateparamfields.add(field);
    }
   }
   sqlparamfields.addall(updateparamfields);
      // 放入缓存
   sqlparamfieldsmap.put(clazz, sqlparamfields);
   }
   return sqlparamfields;
  }
  /**
   * 获取表名,对象中使用@table的tablename来标记对应数据库的表名,若未标记则自动将类名转成小写
   * 
   * @param clazz
   * @return
   */
  public static string gettablename(class<?> clazz) {
    table table = clazz.getannotation(table.class);
    if (table != null && table.tablename() != null && !"".equals(table.tablename())) {
      return table.tablename();
    }
    // 当未配置@table的tablename,自动将类名转成小写
    return clazz.getsimplename().tolowercase();
  }
  /**
   * 获取数据库名,对象中使用@table的dbname来标记对应数据库名
   * @param clazz
   * @return
   */
  public static string getdbname(class<?> clazz) {
    table table = clazz.getannotation(table.class);
    if (table != null && table.dbname() != null) {
      // 注解@table的dbname
      return table.dbname();
    }
    return "";
  }

第四步,生成sql语句

根据上一步的方法,生成真正执行的sql语句。

insert into company_candidate(company_id,user_id,card_id,facebook_id,type,create_time,weight,score) values (?,?,?,?,?,?,?,?) on duplicate key update type=?,weight=?,score=?

代码如下:

public class sqlutil {
  private static final char comma = ',';
  private static final char brackets_begin = '(';
  private static final char brackets_end = ')';
  private static final char question_mark = '?';
  private static final char equal_sign = '=';
  private static final string insert_begin = "insert into ";
  private static final string insert_valurs = " values ";
  private static final string duplicate_update = " on duplicate key update ";
  // 数据库表名和对应insertupdatesql的缓存
  private static final map<string string=""> tableinsertsqlmap = new hashmap<string string="">();
  /**
   * 获取插入的sql语句,对象中使用@tablefield的fieldname来标记对应数据库的列名,若未标记则忽略
   * 必须标记@tablefield(fieldname = "company_id")注解
   * @param tablename
   * @param fieldsmap
   * @return
   * @throws exception
   */
  public static string getinsertsql(string tablename, map<string field=""> fieldsmap) throws exception {
   string sql = tableinsertsqlmap.get(tablename);
   if (sql == null) {
   stringbuilder sbsql = new stringbuilder(300).append(insert_begin);
   stringbuilder sbvalue = new stringbuilder(insert_valurs);
   stringbuilder sbupdate = new stringbuilder(100).append(duplicate_update);
   sbsql.append(tablename);
   sbsql.append(brackets_begin);
   sbvalue.append(brackets_begin);
   iterator<entry<string field="">> iter = fieldsmap.entryset().iterator();
   while (iter.hasnext()) {
    entry<string field=""> entry = (entry<string field="">) iter.next();
    string tablefieldname = entry.getkey();
    field field = entry.getvalue();
    sbsql.append(tablefieldname);
    sbsql.append(comma);
    sbvalue.append(question_mark);
    sbvalue.append(comma);
    tablefield tablefield = field.getannotation(tablefield.class);
    if (tablefield != null && tablefield.update()) {
    sbupdate.append(tablefieldname);
    sbupdate.append(equal_sign);
    sbupdate.append(question_mark);
    sbupdate.append(comma);
    }
   }
   // 去掉最后的逗号
   sbsql.deletecharat(sbsql.length() - 1);
   sbvalue.deletecharat(sbvalue.length() - 1);
   sbsql.append(brackets_end);
   sbvalue.append(brackets_end);
   sbsql.append(sbvalue);
   if (!sbupdate.tostring().equals(duplicate_update)) {
    sbupdate.deletecharat(sbupdate.length() - 1);
    sbsql.append(sbupdate);
   }
   sql = sbsql.tostring();
   tableinsertsqlmap.put(tablename, sql);
   }
    return sql;
  }

第五步,批量sql插入实现

从连接池获取connection,sqlutil.getinsertsql()获取执行的sql语句,根据sqlparamfields来为preparedstatement填充参数值。当循环的值集合到达batchnum时就提交一次。

代码如下:

  /**
   * 批量插入,如果主键一致则更新。结果返回更新记录条数<br />
   * @param datalist
   *      要插入的对象list
   * @param batchnum
   *      每次批量插入条数
   * @return 更新记录条数
   */
  public int batchinsertsql(list<? extends object> datalist, int batchnum) throws exception {
   if (datalist == null || datalist.isempty()) {
   return 0;
   }
    class<?> clazz = datalist.get(0).getclass();
    string tablename = reflectutil.gettablename(clazz);
    string dbname = reflectutil.getdbname(clazz);
    connection connnection = null;
    preparedstatement preparedstatement = null;
    // 获取所有需要更新到db的属性域
    map<string field=""> fieldsmap = reflectutil.gettablebeanfieldmap(datalist.get(0).getclass());
    // 根据需要插入更新的字段生成sql语句
    string sql = sqlutil.getinsertsql(tablename, fieldsmap);
    log.debug("prepare to start batch operation , sql = " + sql + " , dbname = " + dbname);
    // 获取和sql语句同样顺序的填充参数fields
    list<field> sqlparamfields = reflectutil.getsqlparamfields(datalist.get(0).getclass());
    // 最终更新结果条数
    int result = 0;
    int parameterindex = 1;// sql填充参数开始位置为1
    // 执行错误的对象
    list<object> errorsrecords = new arraylist</object><object>(batchnum);//指定数组大小
    // 计数器,batchnum提交后内循环累计次数
    int innercount = 0;
    try {
      connnection = this.getconnection(dbname);
      // 设置非自动提交
      connnection.setautocommit(false);
      preparedstatement = connnection.preparestatement(sql);
      // 当前操作的对象
      object object = null;
      int totalrecordcount = datalist.size();
      for (int current = 0; current < totalrecordcount; current++) {
        innercount++;
        object = datalist.get(current);
       parameterindex = 1;// 开始参数位置为1
       for(field field : sqlparamfields) {
       // 放入insert语句对应sql参数
          preparedstatement.setobject(parameterindex++, field.get(object));
       }
       errorsrecords.add(object);
        preparedstatement.addbatch();
        // 达到批量次数就提交一次
        if (innercount >= batchnum || current >= totalrecordcount - 1) {
          // 执行batch操作
          preparedstatement.executebatch();
          preparedstatement.clearbatch();
          // 提交
          connnection.commit();
          // 记录提交成功条数
          result += innercount;
          innercount = 0;
          errorsrecords.clear();
        }
        // 尽早让gc回收
        datalist.set(current, null);
      }
      return result;
    } catch (exception e) {
      // 失败后处理方法
      callbackimpl.getinstance().exectuer(sql, errorsrecords, e);
      batchdbexception be = new batchdbexception("batch run error , dbname = " + dbname + " sql = " + sql, e);
      be.initcause(e);
      throw be;
    } finally {
      // 关闭
      if (preparedstatement != null) {
       preparedstatement.clearbatch();
        preparedstatement.close();
      }
      if (connnection != null)
        connnection.close();
    }
  }

最后,批量工具类使用例子

在mysql下的开发环境下测试,5万条数据大概13秒。

list<companycandidatemodel> updatedatalist = new arraylist<companycandidatemodel>(50000);
// ...为updatedatalist填充数据
int result = batchjdbctemplate.batchinsertsql(updatedatalist, 50);

总结

以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作具有一定的参考学习价值,谢谢大家对的支持。如果你想了解更多相关内容请查看下面相关链接