2016-12-26 35 views
1

我试图编译从Numeric.AD以下小例子:最小Numeric.AD例如不会编译

import Numeric.AD 
timeAndGrad f l = grad f l 
main = putStrLn "hi" 

,我碰到这个错误:

test.hs:3:24: 
    Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse 
             s a) 
            -> Numeric.AD.Internal.Reverse.Reverse s a’ 
       with actual type ‘t’ 
     because type variable ‘s’ would escape its scope 
    This (rigid, skolem) type variable is bound by 
     a type expected by the context: 
     Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape => 
     f (Numeric.AD.Internal.Reverse.Reverse s a) 
     -> Numeric.AD.Internal.Reverse.Reverse s a 
     at test.hs:3:19-26 
    Relevant bindings include 
     l :: f a (bound at test.hs:3:15) 
     f :: t (bound at test.hs:3:13) 
     timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1) 
    In the first argument of ‘grad’, namely ‘f’ 
    In the expression: grad f l 

任何线索为什么会发生这种情况?通过观察前面的例子据我了解,这是“扁平化” grad的类型:

grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a

,但我真的需要做这样的事情在我的代码。事实上,这是不能编译的最简单的例子。我想要做的更复杂的事情是这样的:

example :: SomeType 
example f x args = (do stuff with the gradient and gradient "function") 
    where gradient = grad f x 
      gradientFn = grad f 
      (other where clauses involving gradient and gradient "function") 

这里有一个稍微复杂一点的版本类型签名,它编译。

{-# LANGUAGE RankNTypes #-} 

import Numeric.AD 
import Numeric.AD.Internal.Reverse 

-- compiles but I can't figure out how to use it in code 
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a] 
grad2 f l = grad f l 

-- compiles with the right type, but the resulting gradient is all 0s... 
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a] 
grad2' f l = grad f' l 
     where f' = Lift . f . extractAll 
     -- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work 

extractAll :: [Reverse t a] -> [a] 
extractAll xs = map extract xs 
      where extract (Lift x) = x -- non-exhaustive pattern match 

dist :: (Show a, Num a, Floating a) => [a] -> a 
dist [x, y] = sqrt(x^2 + y^2) 

-- incorrect output: [0.0, 0.0] 
main = putStrLn $ show $ grad2' dist [1,2] 

但是,我无法弄清楚如何使用第一个版本,grad2,在代码,因为我不知道该如何处理Reverse s a。第二个版本grad2'具有正确的类型,因为我使用内部构造函数Lift来创建Reverse s a,但我不能理解内部(特别是参数s)的工作方式,因为输出渐变全部为0。使用其他构造函数Reverse(此处未显示)也会产生错误的渐变。

或者,有没有人们使用ad代码的库/代码的例子?我认为我的用例是非常普遍的用例。

+2

如果您向timeAndGrad提供类型签名,会发生什么情况?一级方法可能会带来更多运气。 – ocharles

+0

我编辑我的问题,添加一个类型签名和另一种方法(这也不起作用)。 – kye

回答

2

随着where f' = Lift . f . extractAll你基本上创建了一个后门进入自动分化基础类型,抛出所有的派生物,只保留常量值。如果你使用这个为grad,你得到一个零结果并不奇怪!

明智的办法是只使用grad,因为它是:

dist :: Floating a => [a] -> a 
dist [x, y] = sqrt $ x^2 + y^2 
-- preferrable is of course `dist = sqrt . sum . map (^2)` 

main = print $ grad dist [1,2] 
-- output: [0.4472135954999579,0.8944271909999159] 

你并不真的需要知道什么更复杂的使用自动分化。只要你只区分NumFloating多态函数,一切都会按原样运行。如果您需要区分作为参数传入的函数,则需要将该参数设为rank-2多态(可以选择切换到ad函数的rank-1版本,但我敢说这不太优雅,并没有真正获得你的好处)。

{-# LANGUAGE Rank2Types, UnicodeSyntax #-} 

mainWith :: (∀n . Floating n => [n] -> n) -> IO() 
mainWith f = print $ grad f [1,2] 

main = mainWith dist 
+0

是的,我需要区分作为参数传入的函数。你能解释一下更多你的意思吗?“使这个参数排名-2多态?”我也尝试切换到grad的1级版本,并且它要求我指定函数具有类型([Forward a] - > Forward a)。 – kye

+0

这个类型不允许在这里使用类型为(Num a => [a] - > a)的函数,但是我可以在代码中传递它。我不知道为什么类型的行为是这样的。 – kye