欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Spring动态创建bean　博客分类: java spring ioc bean

程序员文章站 2024-02-06 19:10:10
...
最近有个项目场景：需要支持多个垂类，各垂类的业务流程大体相同，只是部分业务规则的校验参数不同。解决思路是将业务参数作为类的属性，然后为每个垂类各创建一个实例，分别处理各垂类的业务。

看了spring ioc部分的代码后，个人感觉比较讨巧的方式是：在spring完成bean创建的过程后，由一个实现了ApplicationContextAware接口的类，参照需要的BeanDefinition（复用其bean类）为每个垂类注册一个新的BeanDefinition，并赋予不同的业务参数属性值。新增加的BeanDefinition对应的实例会在getBean的过程中由spring创建。

下面分两部分介绍:
1、动态创建bean的代码实现
2、spring的ioc源码解读,这部分放到另外一篇博客http://mazhen2010.iteye.com/blog/2283773
本文使用的Spring版本：<spring.version>4.0.6.RELEASE</spring.version>

【动态创建bean的代码实现】
1、创建一个实现ApplicationContextAware接口的类,然后获取DefaultListableBeanFactory
    /**
     * Resolves the {@link DefaultListableBeanFactory} behind {@code applicationContext} and stores it
     * in {@code springFactory}.
     *
     * <p>Supported contexts are {@link AbstractRefreshableApplicationContext} subclasses (e.g.
     * XmlWebApplicationContext, ClassPathXmlApplicationContext) and {@link GenericApplicationContext}.
     * For any other context, or when the refreshable context's bean factory is not a
     * {@code DefaultListableBeanFactory}, an error is logged and {@code springFactory} is left
     * unchanged.
     *
     * @param applicationContext the application context the application runs in
     */
    private void setSpringFactory(ApplicationContext applicationContext) {

        if (applicationContext instanceof AbstractRefreshableApplicationContext) {
            // suit both XmlWebApplicationContext and ClassPathXmlApplicationContext
            AbstractRefreshableApplicationContext springContext = (AbstractRefreshableApplicationContext) applicationContext;
            Object beanFactory = springContext.getBeanFactory();
            if (!(beanFactory instanceof DefaultListableBeanFactory)) {
                LOGGER.error("No suitable bean factory! The current factory class is {}",
                        beanFactory.getClass());
                // Do not fall through to the cast below: it would throw ClassCastException.
                return;
            }
            springFactory = (DefaultListableBeanFactory) beanFactory;
        } else if (applicationContext instanceof GenericApplicationContext) {
            // suit GenericApplicationContext
            GenericApplicationContext springContext = (GenericApplicationContext) applicationContext;
            springFactory = springContext.getDefaultListableBeanFactory();
        } else {
            LOGGER.error("No suitable application context! The current context class is {}",
                    applicationContext.getClass());
        }
    }


2、定义注解，用于找到需要克隆的BeanDefinition和需要赋值的属性
/**
 * Marks a template service implementation class.
 *
 * <p>It is meta-annotated with {@code @Component}, presumably so that component scanning
 * registers annotated classes as beans (with {@link #value()} as the bean name) — confirm
 * against the scanning configuration.
 */
@Documented
@Component
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface TemplateService {

    /** Implementation name of the service; empty by default. */
    String value() default "";

    /** Name of the service this class implements. */
    String serviceName();
}

/**
 * Marks a field of a {@code @TemplateService} class as a business parameter. Its value is
 * supplied per template and is looked up by the field's name.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface TemplateBizParam {
}

/**
 * Demo template implementation of {@code DemoService}. It prints the received ids together with
 * its per-template business parameter and extend fields.
 */
@TemplateService(serviceName = "demoService", value = "demoServiceImpl")
public class DemoServiceImpl extends AbstractServiceImpl implements DemoService {

    /** Business parameter; its value is supplied per template under the key "noVisitDays". */
    @TemplateBizParam
    private String noVisitDays;

    @Override
    public void doDemo(Long poiId) {
        String message = "doDemo" + "//" + "poiId:" + poiId
                + "//" + noVisitDays
                + "//" + getExtendFields() + "//"
                + "abc:" + getExtendField("abc");
        System.out.println(message);
    }

    @Override
    public void doDemos(List<Long> poiIds) {
        StringBuilder message = new StringBuilder("poiIds");
        message.append(poiIds).append("; noVisitDays:").append(noVisitDays);
        System.out.println(message);
    }

}


3、从垂类模板中获取需要动态创建的bean信息,然后注册BeanDefinition
    /**
     * Registers a per-template bean for the service described by {@code serviceEntity} and injects
     * the template's business parameters and extend fields into the created instance.
     *
     * <p>The new definition is named {@code generateTemplateBeanName(templateId, serviceName)}. It
     * is built from the bean class name of the existing definition {@code serviceEntity.getImplName()};
     * only the class is reused, not the rest of that definition's configuration. Nothing is registered
     * when that definition does not exist, has no class name, or its class cannot be loaded.
     *
     * @param templateId    template id
     * @param serviceEntity service information from the template
     */
    private void registerBeanDefinition(String templateId, ServiceEntity serviceEntity) {

        String implName = serviceEntity.getImplName();
        try {
            if (!springFactory.containsBeanDefinition(implName)) {
                return;
            }
            //step1: register one bean definition per template
            String beanKey = generateTemplateBeanName(templateId, serviceEntity.getServiceName());
            BeanDefinition beanDefinition = springFactory.getBeanDefinition(implName);
            String className = beanDefinition.getBeanClassName();
            if (className == null) {
                LOGGER.error("Bean definition has no bean class name. beanName:{}, implName:{}", beanKey, implName);
                return;
            }
            Class<?> beanClass;
            try {
                beanClass = Class.forName(className);
            } catch (ClassNotFoundException e) {
                // Continuing would call getBean/injectParamVaules with a null class and fail with an NPE.
                LOGGER.error("Bean class not found, skip registering. beanName:" + beanKey + ", className:" + className, e);
                return;
            }

            BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.rootBeanDefinition(className);
            beanDefinitionBuilder.getBeanDefinition().setAttribute("id", beanKey);

            springFactory.registerBeanDefinition(beanKey, beanDefinitionBuilder.getBeanDefinition());
            LOGGER.info("Register bean definition successfully. beanName:{}, implName:{}", beanKey, implName);

            //step2: instantiate the bean and inject the template's parameters into it
            Object bean = springFactory.getBean(beanKey, beanClass);
            injectParamVaules(bean, beanClass, serviceEntity);

        } catch (NoSuchBeanDefinitionException ex) {
            LOGGER.info("No bean definition in spring factory. implName:{}", implName);
        } catch (BeanDefinitionStoreException ex) {
            LOGGER.warn("Register bean definition wrong. beanName:"
                    + generateTemplateBeanName(templateId, serviceEntity.getServiceName()) + ", implName:" + implName, ex);
        }
    }

    /**
     * Injects template-specific values into a template bean via reflection. Does nothing unless
     * {@code requiredType} carries {@code @TemplateService}.
     *
     * <ul>
     *   <li>Each field declared directly on {@code requiredType} and annotated with
     *       {@code @TemplateBizParam} receives the {@code serviceEntity.getBizParamMap()} entry whose
     *       key equals the field name. Fields without a matching key are left untouched.</li>
     *   <li>Each field declared on the {@code AbstractService} ancestor returned by getSuperClass and
     *       annotated with {@code @TemplateExtendFields} receives {@code serviceEntity.getExtendFields()},
     *       when that value is non-null.</li>
     * </ul>
     * A field that cannot be set is logged and skipped; the remaining fields are still processed.
     *
     * @param bean          instance to inject into
     * @param requiredType  class whose declared fields are scanned
     * @param serviceEntity source of the business parameters and extend fields
     */
    private <T> void injectParamVaules(Object bean, Class<T> requiredType, ServiceEntity serviceEntity) {

        if (!requiredType.isAnnotationPresent(TemplateService.class)) {
            return;
        }

        // Business parameters: fields declared on the implementation class itself.
        for (Field field : requiredType.getDeclaredFields()) {
            if (!field.isAnnotationPresent(TemplateBizParam.class)) {
                continue;
            }
            String paramName = field.getName();
            if (serviceEntity.getBizParamMap() == null || !serviceEntity.getBizParamMap().containsKey(paramName)) {
                continue;
            }
            Object value = serviceEntity.getBizParamMap().get(paramName);
            if (setFieldValue(bean, field, value, "inject biz param failed. paramName = ")) {
                LOGGER.info("inject biz param value successfully, paramName = {}, value = {}", paramName, value);
            }
        }

        // Extend fields: fields declared on the AbstractService ancestor.
        Class<AbstractService> superClass = getSuperClass(requiredType);
        if (superClass == null || serviceEntity.getExtendFields() == null) {
            return;
        }
        for (Field field : superClass.getDeclaredFields()) {
            if (field.isAnnotationPresent(TemplateExtendFields.class)
                    && setFieldValue(bean, field, serviceEntity.getExtendFields(), "inject extend fields failed. field = ")) {
                LOGGER.info("inject extend fields successfully, extendFields = {}", serviceEntity.getExtendFields());
            }
        }
    }

    /**
     * Writes {@code value} into {@code field} of {@code bean}, logging (not throwing) on failure.
     *
     * @param failureMessage log message prefix; the field name is appended
     * @return true if the value was written
     */
    private boolean setFieldValue(Object bean, Field field, Object value, String failureMessage) {
        field.setAccessible(true);
        try {
            field.set(bean, value);
            return true;
        } catch (IllegalAccessException e) {
            LOGGER.error(failureMessage + field.getName(), e);
        } catch (IllegalArgumentException e) {
            // Field.set throws this when bean is not an instance of the declaring class
            // or the value cannot be converted to the field type.
            LOGGER.error(failureMessage + field.getName(), e);
        }
        return false;
    }


4、定义一个继承AbstractServiceContext的Context类，在运行时根据策略选取所需的业务实例进行处理
/**
 * Entry-point bean for {@code DemoService}. Each call is delegated to the template-specific
 * implementation chosen by {@code AbstractServiceContext.getServiceImpl}.
 */
@Service("demoService")
public class DemoServiceContext extends AbstractServiceContext implements DemoService {

    @Override
    public void doDemo(Long poiId) {
        getServiceImpl(poiId, DemoService.class).doDemo(poiId);
    }

    /**
     * Delegates a batch call. The implementation is selected from the template of the first poiId
     * only, so callers whose ids span several templates should split them first (e.g. with
     * groupPoiIds).
     */
    @Override
    public void doDemos(java.util.List<Long> poiIds) {
        getServiceImpl(poiIds, DemoService.class).doDemos(poiIds);
    }

}

/**
 * Abstract service context: selects the template-specific service implementation for a request
 * and passes extend fields to it.
 *
 * <p>The service name comes from the {@code @Service} / {@code @Component} value of the concrete
 * subclass. The implementation is fetched from {@link TemplateBeanFactory} under
 * {@code TemplateBeanFactory.generateTemplateBeanName(templateId, serviceName)}.
 *
 * <p>NOTE(review): {@link #setExtendField(String, Object)} writes into the extend-field map of the
 * looked-up implementation bean. If that bean is a shared singleton, values written for one request
 * are visible to, and can be overwritten by, concurrent requests — confirm this is acceptable.
 * User: mazhen01
 * Date: 2016/3/3
 * Time: 10:14
 */
public abstract class AbstractServiceContext {

    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractServiceContext.class);

    @Resource
    private TemplateBeanFactory templateBeanFactory;

    @Autowired
    public TemplateFunction templateFunction;

    // Bean name resolved by the last getServiceImpl(Long, Class) call on the current thread.
    // NOTE(review): never cleared, so a pooled thread may see a stale name from an earlier request.
    private final ThreadLocal<String> currentTemplateBeanName = new ThreadLocal<String>();

    /**
     * Returns the service implementation for the template of the POI's industry, and remembers its
     * bean name for the current thread.
     *
     * @param poiId poiId
     * @param clazz service interface
     * @param <T>   instance type
     * @return the template-specific implementation
     * @throws TemplateException   if the subclass has no service name, or no template is configured for the POI
     * @throws AnnotationException declared for subclasses overriding getServiceName
     * @throws BeansException      if the bean cannot be obtained
     */
    protected <T> T getServiceImpl(Long poiId, Class<T> clazz) throws AnnotationException, BeansException {
        String serviceName = getServiceName();
        if (serviceName == null) {
            // generateTemplateBeanName would otherwise fail with a NullPointerException.
            throw new TemplateException("serviceName is null, " + this.getClass() + " has no Service or Component annotation.");
        }
        String templateId = templateFunction.getTemplateId(poiId, serviceName);
        if (templateId == null) {
            LOGGER.error("templateId is null. No templateId id configured for poiId = {}.", poiId);
            throw new TemplateException("templateId is null, can not find templateId.");
        }
        String templateBeanName = TemplateBeanFactory.generateTemplateBeanName(templateId, serviceName);
        currentTemplateBeanName.set(templateBeanName);
        return templateBeanFactory.getBean(templateBeanName, clazz);
    }

    /**
     * Returns the service implementation for the template of the first poiId in the list.
     *
     * @throws TemplateException if {@code poiIds} is null or empty
     */
    protected <T> T getServiceImpl(List<Long> poiIds, Class<T> clazz) throws AnnotationException, BeansException {
        if (CollectionUtils.isEmpty(poiIds)) {
            LOGGER.error("poiIds List is null");
            throw new TemplateException("poiIds is null.");
        }
        return getServiceImpl(poiIds.get(0), clazz);
    }

    /**
     * Returns the service implementation registered under the given bean name.
     *
     * @param templateBeanName beanName
     * @param clazz            service interface
     * @param <T>              instance type
     * @return the bean
     * @throws AnnotationException declared for signature compatibility
     * @throws BeansException      if the bean cannot be obtained
     */
    protected <T> T getServiceImpl(String templateBeanName, Class<T> clazz) throws AnnotationException, BeansException {
        return templateBeanFactory.getBean(templateBeanName, clazz);
    }

    /**
     * Returns the extend-field names of the implementation chosen for the POI's industry.
     *
     * @param poiId poiId
     * @return the extend-field names, or an empty list when there are none
     */
    public List<String> getExtendFields(Long poiId) {
        AbstractServiceImpl abstractService = getServiceImpl(poiId, AbstractServiceImpl.class);

        if (abstractService == null || CollectionUtils.isEmpty(abstractService.getExtendFields())) {
            // The empty list used to be created and discarded, so a null service caused an NPE below.
            return Lists.newArrayList();
        }

        return abstractService.getExtendFields();
    }

    /**
     * For the implementation chosen for the POI's industry, copies each declared extend field
     * from the request attributes of the same name.
     *
     * @param poiId   poiId
     * @param request user request; nothing happens when null
     */
    public void setExtendField(Long poiId, HttpServletRequest request) {

        if (request == null) {
            return;
        }

        AbstractServiceImpl abstractService = getServiceImpl(poiId, AbstractServiceImpl.class);

        if (abstractService == null || CollectionUtils.isEmpty(abstractService.getExtendFields())) {
            return;
        }

        for (String field : abstractService.getExtendFields()) {
            setExtendField(field, request.getAttribute(field));
        }
    }

    /**
     * Stores an extend-field value on the implementation last resolved on this thread. Does nothing
     * when no implementation has been resolved on this thread yet.
     *
     * @param field field name
     * @param value value
     */
    public void setExtendField(String field, Object value) {
        String templateBeanName = currentTemplateBeanName.get();
        if (StringUtils.isEmpty(templateBeanName)) {
            return;
        }
        AbstractServiceImpl abstractService = getServiceImpl(templateBeanName, AbstractServiceImpl.class);
        abstractService.getExtendFieldMap().put(field, value);
    }

    /**
     * Returns the bean name declared on the concrete subclass: the {@code @Service} value if
     * present, otherwise the {@code @Component} value, otherwise null (after logging an error).
     */
    protected String getServiceName() throws AnnotationException {
        Service service = this.getClass().getAnnotation(Service.class);
        if (service != null) {
            return service.value();
        }

        Component component = this.getClass().getAnnotation(Component.class);
        if (component != null) {
            return component.value();
        }

        LOGGER.error("Has no annotation.");
        return null;
    }

    /**
     * Groups poiIds by their category template.
     *
     * @param poiIds poiIds to group
     * @return the grouping computed by templateFunction
     */
    public Map<Long, List<Long>> groupPoiIds(List<Long> poiIds) {
        return templateFunction.groupPoiIds(poiIds);
    }

}


5、在springContext.xml中声明TemplateBeanFactory
<!-- TemplateBeanFactory registers the per-template service beans once the application context is set. -->
<bean class="com.baidu.nuomi.tpl.spring.TemplateBeanFactory"></bean>

TemplateBeanFactory的完整代码（包括模板变化时的刷新逻辑）：
/**
 * Bean factory for template services: for each service listed in a template it registers a
 * per-template bean and fills in its business parameters and extend fields.
 * The periodic check for template changes is not part of this class; callers pass the changed
 * templates to {@link #refreshTemplateBeans(List)}, which replaces the affected beans.
 * User: mazhen01
 * Date: 2016/3/1
 * Time: 16:46
 */
public class TemplateBeanFactory implements ApplicationContextAware {

    private static final Logger LOGGER = LoggerFactory.getLogger(TemplateBeanFactory.class);

    private DefaultListableBeanFactory springFactory;

    @Autowired
    TemplateFunction templateFunction;

    /**
     * Captures the spring bean factory, initializes templateFunction and registers the beans of
     * all templates.
     *
     * @throws IllegalStateException if no DefaultListableBeanFactory can be obtained from the context
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        setSpringFactory(applicationContext);
        if (springFactory == null) {
            // Nothing can be registered without a bean factory; fail with a clear message
            // instead of a NullPointerException further down.
            throw new IllegalStateException("No DefaultListableBeanFactory available from application context "
                    + applicationContext.getClass());
        }
        templateFunction.init();
        loadTemplateBeanDefinitions(templateFunction.getAllTemplateEntity());
    }

    /**
     * Refreshes the template beans of the given templates: their bean definitions are removed and
     * then registered again from {@code changedTemplates}.
     *
     * <p>NOTE(review): between removal and re-registration, lookups of these bean names fail.
     * Also, only services listed in {@code changedTemplates} are removed, so a service dropped from
     * a template presumably keeps its old bean. Confirm both are acceptable.
     *
     * @param changedTemplates templates whose services changed; no-op when empty
     */
    public void refreshTemplateBeans(List<TemplateEntity> changedTemplates) {
        LOGGER.info("Refresh changed template beans start.");
        if (CollectionUtils.isEmpty(changedTemplates)) {
            LOGGER.info("no template beans is changed");
            return;
        }
        destroyTemplateBeans(changedTemplates);
        loadTemplateBeanDefinitions(changedTemplates);
        LOGGER.info("Refresh changed template beans end.");
    }

    /**
     * Obtains the DefaultListableBeanFactory from the application context in use. On failure an
     * error is logged and {@code springFactory} is left unchanged.
     *
     * @param applicationContext the application context the application runs in
     */
    private void setSpringFactory(ApplicationContext applicationContext) {

        if (applicationContext instanceof AbstractRefreshableApplicationContext) {
            // suit both XmlWebApplicationContext and ClassPathXmlApplicationContext
            AbstractRefreshableApplicationContext springContext = (AbstractRefreshableApplicationContext) applicationContext;
            Object beanFactory = springContext.getBeanFactory();
            if (!(beanFactory instanceof DefaultListableBeanFactory)) {
                LOGGER.error("No suitable bean factory! The current factory class is {}",
                        beanFactory.getClass());
                // Do not fall through to the cast below: it would throw ClassCastException.
                return;
            }
            springFactory = (DefaultListableBeanFactory) beanFactory;
        } else if (applicationContext instanceof GenericApplicationContext) {
            // suit GenericApplicationContext
            GenericApplicationContext springContext = (GenericApplicationContext) applicationContext;
            springFactory = springContext.getDefaultListableBeanFactory();
        } else {
            LOGGER.error("No suitable application context! The current context class is {}",
                    applicationContext.getClass());
        }
    }

    /**
     * Registers a template bean for every service of every template. Null templates, templates
     * without services or without an industryId, and null services are skipped.
     */
    private void loadTemplateBeanDefinitions(List<TemplateEntity> templateEntityList) {
        if (CollectionUtils.isEmpty(templateEntityList)) {
            LOGGER.warn("No template entity to load.");
            return;
        }
        for (TemplateEntity templateEntity : templateEntityList) {
            if (templateEntity == null || CollectionUtils.isEmpty(templateEntity.getServiceList())) {
                continue;
            }
            if (templateEntity.getIndustryId() == null) {
                LOGGER.warn("Template has no industryId, skip loading its services.");
                continue;
            }
            String templateId = templateEntity.getIndustryId().toString();
            for (ServiceEntity serviceEntity : templateEntity.getServiceList()) {
                if (serviceEntity != null) {
                    registerBeanDefinition(templateId, serviceEntity);
                }
            }
        }
    }

    /**
     * Registers a per-template bean for the service and injects the template's parameters.
     *
     * <p>The new definition is named {@code generateTemplateBeanName(templateId, serviceName)}. It
     * is built from the bean class name of the existing definition {@code serviceEntity.getImplName()};
     * only the class is reused, not the rest of that definition's configuration. Nothing is registered
     * when that definition does not exist, has no class name, or its class cannot be loaded.
     *
     * @param templateId    template id
     * @param serviceEntity service information
     */
    private void registerBeanDefinition(String templateId, ServiceEntity serviceEntity) {

        String implName = serviceEntity.getImplName();
        try {
            if (!springFactory.containsBeanDefinition(implName)) {
                return;
            }
            //step1: register one bean definition per template
            String beanKey = generateTemplateBeanName(templateId, serviceEntity.getServiceName());
            BeanDefinition beanDefinition = springFactory.getBeanDefinition(implName);
            String className = beanDefinition.getBeanClassName();
            if (className == null) {
                LOGGER.error("Bean definition has no bean class name. beanName:{}, implName:{}", beanKey, implName);
                return;
            }
            Class<?> beanClass;
            try {
                beanClass = Class.forName(className);
            } catch (ClassNotFoundException e) {
                // Continuing would call getBean/injectParamVaules with a null class and fail with an NPE.
                LOGGER.error("Bean class not found, skip registering. beanName:" + beanKey + ", className:" + className, e);
                return;
            }

            BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.rootBeanDefinition(className);
            beanDefinitionBuilder.getBeanDefinition().setAttribute("id", beanKey);

            springFactory.registerBeanDefinition(beanKey, beanDefinitionBuilder.getBeanDefinition());
            LOGGER.info("Register bean definition successfully. beanName:{}, implName:{}", beanKey, implName);

            //step2: instantiate the bean and inject the template's parameters into it
            Object bean = springFactory.getBean(beanKey, beanClass);
            injectParamVaules(bean, beanClass, serviceEntity);

        } catch (NoSuchBeanDefinitionException ex) {
            LOGGER.info("No bean definition in spring factory. implName:{}", implName);
        } catch (BeanDefinitionStoreException ex) {
            LOGGER.warn("Register bean definition wrong. beanName:"
                    + generateTemplateBeanName(templateId, serviceEntity.getServiceName()) + ", implName:" + implName, ex);
        }
    }

    /**
     * Injects business parameters and extend fields into a template bean. Does nothing unless
     * {@code requiredType} carries {@code @TemplateService}.
     *
     * <ul>
     *   <li>{@code @TemplateBizParam} fields declared on {@code requiredType} get the
     *       {@code getBizParamMap()} entry keyed by the field name, when present.</li>
     *   <li>{@code @TemplateExtendFields} fields declared on the AbstractService ancestor get
     *       {@code getExtendFields()}, when non-null.</li>
     * </ul>
     * A field that cannot be set is logged and skipped.
     *
     * @param bean          instance to inject into
     * @param requiredType  class whose declared fields are scanned
     * @param serviceEntity source of the values
     * @param <T>           bean type
     */
    private <T> void injectParamVaules(Object bean, Class<T> requiredType, ServiceEntity serviceEntity) {

        if (!requiredType.isAnnotationPresent(TemplateService.class)) {
            return;
        }

        // Business parameters: fields declared on the implementation class itself.
        for (Field field : requiredType.getDeclaredFields()) {
            if (!field.isAnnotationPresent(TemplateBizParam.class)) {
                continue;
            }
            String paramName = field.getName();
            if (serviceEntity.getBizParamMap() == null || !serviceEntity.getBizParamMap().containsKey(paramName)) {
                continue;
            }
            Object value = serviceEntity.getBizParamMap().get(paramName);
            if (setFieldValue(bean, field, value, "inject biz param failed. paramName = ")) {
                LOGGER.info("inject biz param value successfully, paramName = {}, value = {}", paramName, value);
            }
        }

        // Extend fields: fields declared on the AbstractService ancestor.
        Class<AbstractService> superClass = getSuperClass(requiredType);
        if (superClass == null || serviceEntity.getExtendFields() == null) {
            return;
        }
        for (Field field : superClass.getDeclaredFields()) {
            if (field.isAnnotationPresent(TemplateExtendFields.class)
                    && setFieldValue(bean, field, serviceEntity.getExtendFields(), "inject extend fields failed. field = ")) {
                LOGGER.info("inject extend fields successfully, extendFields = {}", serviceEntity.getExtendFields());
            }
        }
    }

    /**
     * Writes {@code value} into {@code field} of {@code bean}, logging (not throwing) on failure.
     *
     * @param failureMessage log message prefix; the field name is appended
     * @return true if the value was written
     */
    private boolean setFieldValue(Object bean, Field field, Object value, String failureMessage) {
        field.setAccessible(true);
        try {
            field.set(bean, value);
            return true;
        } catch (IllegalAccessException e) {
            LOGGER.error(failureMessage + field.getName(), e);
        } catch (IllegalArgumentException e) {
            // Field.set throws this when bean is not an instance of the declaring class
            // or the value cannot be converted to the field type.
            LOGGER.error(failureMessage + field.getName(), e);
        }
        return false;
    }

    /**
     * Returns {@code AbstractService.class} when {@code clazz} is a strict subclass of it, walking
     * up the superclass chain. Otherwise returns null.
     */
    private Class<AbstractService> getSuperClass(Class<?> clazz) {
        if (clazz == null || !AbstractService.class.isAssignableFrom(clazz)) {
            LOGGER.info("super class is null");
            return null;
        }
        Class<?> superClass = clazz.getSuperclass();
        return AbstractService.class == superClass ? AbstractService.class : getSuperClass(superClass);
    }

    /**
     * Removes the template bean definitions, and with them their singleton instances, for every
     * service of the given templates.
     */
    private void destroyTemplateBeans(List<TemplateEntity> changedTemplates) {

        if (CollectionUtils.isEmpty(changedTemplates)) {
            LOGGER.warn("No template entity to destroy.");
            return;
        }
        for (TemplateEntity templateEntity : changedTemplates) {
            if (templateEntity == null || CollectionUtils.isEmpty(templateEntity.getServiceList())
                    || templateEntity.getIndustryId() == null) {
                continue;
            }
            String templateId = templateEntity.getIndustryId().toString();
            for (ServiceEntity serviceEntity : templateEntity.getServiceList()) {
                if (serviceEntity == null) {
                    continue;
                }
                String beanName = generateTemplateBeanName(templateId, serviceEntity.getServiceName());
                // Check the definition, not the singleton: a definition whose instance was never
                // created must be removed too, and removeBeanDefinition throws for names that have
                // no definition.
                if (springFactory.containsBeanDefinition(beanName)) {
                    // No explicit destroySingleton needed: removeBeanDefinition already destroys the instance.
                    springFactory.removeBeanDefinition(beanName);
                    LOGGER.info("destroy template beans successfully for beanName = {}", beanName);
                }
            }
        }
    }


    /**
     * Returns a bean from springFactory.
     *
     * @param name         bean name
     * @param requiredType bean type
     * @param <T>          bean type
     * @return the bean
     * @throws BeansException if the bean cannot be obtained
     */
    public <T> T getBean(String name, Class<T> requiredType) throws BeansException {
        return springFactory.getBean(name, requiredType);
    }

    /**
     * Builds the bean name of a template service instance as {@code serviceName + "_" + templateId},
     * e.g. ("3", "demoService") yields "demoService_3".
     *
     * @param templateId  template ID
     * @param serviceName service name; must not be null
     * @return the bean name
     */
    public static final String generateTemplateBeanName(String templateId, String serviceName) {
        return new StringBuilder(serviceName).append('_').append(templateId).toString();
    }
}