2016-07-30 22 views
2

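Clojure code using Java primitive arrays is 70 times slower than the Scala version

I wrote an edit distance algorithm in both Clojure and Scala. The Scala version runs about 70 times faster than the Clojure version.

Clojure: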

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; initialize distances
    (doseq [i (range 1 (inc n0))] (aset-long distances i 0 i))
    (doseq [j (range 1 (inc n1))] (aset-long distances 0 j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget distances i (dec j))
            del (aget distances (dec i) j)
            match (aget distances (dec i) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset-long distances i j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset-long distances i j (inc min-dist))
          :else (aset-long distances i j min-dist))))
    (aget distances n0 n1)))

Scala:

import scala.collection.mutable.ArrayBuffer

def editDistance(s0: Array[Char], s1: Array[Char]): Int = {
  val n0 = s0.length
  val n1 = s1.length
  val distances = Array.fill(n0 + 1)(ArrayBuffer.fill(n1 + 1)(0))
  for (j <- 0 to n1) { distances(0)(j) = j }
  for (i <- 0 to n0) { distances(i)(0) = i }
  for (i <- 1 to n0; j <- 1 to n1) {
    val ins = distances(i)(j - 1)
    val del = distances(i - 1)(j)
    val matches = distances(i - 1)(j - 1)
    val minDist = (ins :: del :: matches :: Nil).reduceLeft(_ min _)
    if (matches != minDist)
      distances(i)(j) = minDist + 1
    else if (s0(i - 1) == s1(j - 1))
      distances(i)(j) = minDist
    else
      distances(i)(j) = minDist + 1
  }
  distances(n0)(n1)
}

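I use Java arrays in Clojure to get the best performance. I considered type hinting every call to aget, but that made my code perform worse (possibly because make-array already defines a typed array). I also override the Clojure :jvm-opts in project.clj. Even so, the performance gap I get is 70x.

What is wrong with the way I use Java arrays in Clojure?

Thanks for your insight.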

+1

Have you run this through a **profiler**? Pay particular attention to memory consumption. –

+0

@Anony-Mousse Indeed, reflection is consuming > 90% of memory through _java.lang.reflect.Method_. How can this happen, given that distances is a typed 2D array? – user3639782

+0

Maybe some lambda expressions. Does Clojure generate Java 8 bytecode with method references? –

Answers

4

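I think I have figured out the problem.

As you mentioned in the comments, reflective calls take up most of the time. Here is why.

Before analyzing the code, I set *warn-on-reflection* to true: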

(set! *warn-on-reflection* true) 

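Then, if you look at the source of the macro that generates the aset-long function, you will see that for arities of 4 or more it uses apply to call the function again. The same goes for aget with arities of 3 or more. I'm not 100% sure, but I believe the information about argument types gets lost inside the function. Also, if you look closely at the sources of aget and aset, you may notice that both functions can be inlined during compilation, but only for their shortest arities (2 arguments for aget, 3 for aset). We definitely want that: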

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; I've unrolled all multi-dimensional aget/aset calls into nested
    ;; single-dimension calls, so the compiler can inline them.
    ;; I'm also type hinting the first argument of the top-level aget/aset calls.
    ;; The reason is explained next.
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i)) j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    ;; we can leave this call as it is, since it is not inside a loop
    (aget distances n0 n1)))

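Let's compile our new function. Remember the global variable we set at the beginning? With it set to true, the compiler emits a bunch of warnings during compilation: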

Reflection warning, core.clj:75:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int). 
Reflection warning, core.clj:76:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int). 
Reflection warning, core.clj:77:25 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int). 
... 

The problem is that Clojure cannot figure out the type of (make-array Long/TYPE (inc n0) (inc n1)) and marks it as unknown. We need to type hint it:

(let [...
      ;; type hint for a 2D array of primitive longs
      ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))
      ...]
  ...)
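In case the "[[J" string looks arbitrary: it is simply the JVM's class name for a two-dimensional array of primitive longs. A minimal REPL sketch to confirm this (the var name grid is only for illustration and is not part of the code above):

```clojure
;; `grid` is a hypothetical name used only for this illustration.
(def grid (make-array Long/TYPE 2 3))

;; The JVM class name of a 2D array of primitive longs.
(.getName (class grid))          ; => "[[J"

;; Each row is a 1D long array, the class that the ^longs hint refers to.
(.getName (class (aget grid 0))) ; => "[J"
(= (class (aget grid 0)) (class (long-array 0))) ; => true
```

So ^"[[J" describes the outer array, while ^longs describes each row returned by the inner aget.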

At this point, it seems we are all set. The final version looks like this:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; initialize distances
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i)) j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    (aget distances n0 n1)))

Here are the benchmarks:

Before:

> (time (edit-distance i1 i2)) 
"Elapsed time: 4601.025555 msecs" 
291 

After:

> (time (edit-distance i1 i2)) 
"Elapsed time: 27.782828 msecs" 
291 
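
A single call to time also includes JIT warm-up, so the numbers above are best read as a rough comparison. A minimal sketch of a steadier measurement, assuming a hypothetical helper named avg-msecs (not part of clojure.core or of the code above):

```clojure
(defn avg-msecs
  "Hypothetical helper: runs (f) `warmup` times untimed, then returns the
  average wall-clock time in milliseconds over `n` timed runs."
  [f warmup n]
  (dotimes [_ warmup] (f))
  (let [start (System/nanoTime)]
    (dotimes [_ n] (f))
    ;; nanoseconds -> milliseconds, averaged over n runs
    (/ (- (System/nanoTime) start) 1e6 n)))
```

It could be called as (avg-msecs #(edit-distance i1 i2) 5 20), where i1 and i2 are the same inputs used in the timings above.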
+0

Thanks, now I have also learned how checking the source code can help. – user3639782