欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

记一次用Java Stream Api的经历

程序员文章站 2022-03-10 21:45:20
...

最近有个项目需要用到推荐系统,弄了个简单的相似度推荐算法。

数据为:

记一次用Java Stream Api的经历

化简为:

public class Worker {
    /**
     * 用户编号
     */
    private long userId;
    /**
     * 期望城市
     */
    private String expectedCity;
    /**
     * 现在状态
     */
    private int status;
    /**
     * 最高学历
     */
    private String education;
    /**
     * 工作经验
     */
    private int experience;
    /**
     * 星座
     */
    private String constellation;
    /**
     * 年龄
     */
    private int age;
    /**
     * 籍贯
     */
    private String nativePlace;
    /**
     * 自我介绍
     */
    private String introduction;
    /**
     * 所在地区
     */
    private String location;
    /**省略get() set()**/
}

计算策略是:

1、数值越接近,值越大

2、数值相同,返回1,否则返回0

3、如果是与字符串有关的,例如(做饭做卫生,辅带宝宝,做饭好吃,做事麻利,为人干净利落 ,形象好。易沟通。有育婴师证。),则计算余弦距离,在这里没有做分词,因此将此内容的比重下降

算法如下:

public class ScoreCos {
    /**
     * 不分词 纯字符串计算
     * @param text1
     * @param text2
     * @return
     */
    public static double similarScoreCos(String text1, String text2){
        if(text1 == null || text2 == null){
            //只要有一个文本为null,规定相似度分值为0,表示完全不相等
            return 0.0;
        }else if("".equals(text1)&&"".equals(text2)) return 1.0;
        Set<Integer> ASII=new TreeSet<>();
        Map<Integer, Integer> text1Map=new HashMap<>();
        Map<Integer, Integer> text2Map=new HashMap<>();
        for(int i=0;i<text1.length();i++){
            Integer temp1=new Integer(text1.charAt(i));
            if(text1Map.get(temp1)==null) text1Map.put(temp1,1);
            else text1Map.put(temp1,text1Map.get(temp1)+1);
            ASII.add(temp1);
        }
        for(int j=0;j<text2.length();j++){
            Integer temp2=new Integer(text2.charAt(j));
            if(text2Map.get(temp2)==null) text2Map.put(temp2,1);
            else text2Map.put(temp2,text2Map.get(temp2)+1);
            ASII.add(temp2);
        }
        double xy=0.0;
        double x=0.0;
        double y=0.0;
        //计算
        for (Integer it : ASII) {
            Integer t1=text1Map.get(it)==null?0:text1Map.get(it);
            Integer t2=text2Map.get(it)==null?0:text2Map.get(it);
            xy+=t1*t2;
            x+=Math.pow(t1, 2);
            y+=Math.pow(t2, 2);
        }
        if(x==0.0||y==0.0) return 0.0;
        return xy/Math.sqrt(x*y);
    }


    /**
     * 相同返回1,不同返回0
     * @param o1
     * @param o2
     * @return
     */
    public static double equal(Object o1,Object o2) {
        return (o1!=null && o2!=null)&&o1.equals(o2)?1:0;
    }

    /**
     * 值约接近,返回值越接近1
     * 算法为 1-(大-小)/(最大-最小)
     * @param o1
     * @param o2
     * @return
     */
    public static double similarByNumber(int o1, int o2, int max) {
        return 1-Math.abs(o1-o2)/max;
    }
}

算法大致如下:

1、先从excel获取数据

2、用两个for循环计算物品间的相似度

3、排序后取前10个最大的

4、保存数据

第一次跑,以工作人员的自我介绍作为相似度判断依据

        //数据结构

        //结果
        Map<Long,List<Node>> map = new HashMap<>();

        //结果的每一行
        Map<Long, Double> row = new HashMap<>();

        //文件内容
        Map<Long, String> content = new HashMap<>();

        //读取文件
        File file = new File("d:/data7.xls");
        InputStream inputStream = new FileInputStream(file);
        Workbook workbook = ExcelUtil.getWorkbok(inputStream,file);
        Sheet sheet = workbook.getSheetAt(0);

        //跳过第一个
        for (int i = 1; i < sheet.getLastRowNum(); i++) {
            Row r = sheet.getRow(i);
            Cell id = r.getCell(9);
            Cell cont = r.getCell(5);
            content.put(Long.valueOf(id.getStringCellValue()), cont.getStringCellValue());
        }
//        System.out.println(content);

        //两个for循环计算相似度,取前10个
        for (Map.Entry<Long,String> c1:content.entrySet()) {
            Map<Long, Double> m = new HashMap<>();
            for (Map.Entry<Long,String> c2:content.entrySet()) {
                if(c1.getKey().equals(c2.getKey())) continue;
                double r = ScoreCos.similarScoreCos(c1.getValue(), c2.getValue());
                m.put(c2.getKey(), r);
            }
            List<Map.Entry<Long,Double>> list = new ArrayList<Map.Entry<Long,Double>>(m.entrySet());
            Collections.sort(list,new MyComparator());
            List<Node> nodeList = new ArrayList<>();
            for(int i = 0; i< 10 && i < list.size(); i++){
                Map.Entry<Long, Double> entry = list.get(i);
                nodeList.add(new Node(entry.getKey(), entry.getValue()));
            }
            map.put(c1.getKey(), nodeList);
            log.info("key:{},value:{}",c1.getKey(),nodeList);
        }


        //保存为文件
        save(map);

结果跑了4个小时左右,数据大概有30000个。

推测大概有如下原因:

1、单线程

2、只要取前10个,用不着全排序

将单线程变成多线程有多种方法。其中较为简便的可以用Java1.8提供的并行流处理(parallelStream)

同时,从多个方面进行判断

    /**
     * 值越接近1表示越接近
     * @param o
     * @return
     */
    public double distinct(Worker o){
        double dis = 0;
        dis += (3d / 16) * equal(this.expectedCity, o.expectedCity);
        dis += (1d / 16) * equal(this.status, o.status) ;
        dis += (2d / 16) * similarByNumber(this.experience,o.experience,496);
        dis += (2d / 16) * equal(this.education, o.education);
        dis += (2d / 16) * equal(this.constellation, o.constellation);
        dis += (2d / 16) * similarByNumber(this.age,o.age,40);
        dis += (2d / 16) * equal(this.nativePlace, o.nativePlace);
        dis += (1d / 16) * similarScoreCos(this.introduction, o.introduction);
        dis += (1d / 16) * similarScoreCos(this.location, o.location);
        return dis;
    }

改进后:

    /**
     * 流处理
     * @throws IOException
     */
    private static void useStreamApi() throws IOException {
        List<Worker> data = getFromDB();

        Map<Long, List<Node>> map = new ConcurrentHashMap<>();
        AtomicInteger integer = new AtomicInteger();
        //并发执行
        data.parallelStream().forEach(x->{
            //相当于两个for循环
            List<Node> nodes = data.stream()
                //如果userId相同,则置为0
                .map(y -> new Node(y.getUserId(), x.getUserId()==y.getUserId()?0:x.distinct(y)))
                //降序
                .sorted(Comparator.reverseOrder())
                //取前10个
                .limit(10)
                //.peek(System.out::println)
                .collect(Collectors.toList());
            map.put(x.getUserId(), nodes);
            //每隔100个输出一次
            if(integer.getAndIncrement()%100==0)
                log.info("key:{} value:{}",x.getUserId(),nodes);
        });
        save(map);
    }

重新计算一遍后用了40分钟左右便出来了,而且stream用的也很简洁。

 

参考:

字符串余弦相似度的java简单实现

《写给大忙人看的Java SE 8》第二章