-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathforward.jl
More file actions
128 lines (105 loc) · 3.09 KB
/
forward.jl
File metadata and controls
128 lines (105 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
module forward_tests
# Tests for Diffractor's forward mode: direct application of the `∂☆` transform to
# tangent bundles, and the `PrimeDerivativeFwd` operator bound locally to postfix `'`.
using Diffractor
using ChainRules
using ChainRulesCore
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
using LinearAlgebra
using Test

# Minimal second-order forward smoke test: apply `∂☆{2}` to `sin` on a second-order
# bundle at 1.0 and check that the first canonical tangent component equals `sin'(1.0)`.
@test Diffractor.∂☆{2}()(
    Diffractor.ZeroBundle{2}(sin),
    Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)),
)[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)

# Simple forward-mode tests. Inside a `let var"'" = Diffractor.PrimeDerivativeFwd`
# block, `f'` is the forward-mode derivative of `f`, `f''` the second derivative,
# and so on.
let var"'" = Diffractor.PrimeDerivativeFwd
    # A wrapper around `sin` whose frule simply forwards to `sin`'s frule.
    recursive_sin(x) = sin(x)
    ChainRulesCore.frule(∂, ::typeof(recursive_sin), x) = frule(∂, sin, x)

    # Integration tests. Derivatives of order four and above through the custom
    # frule are currently marked `@test_broken`.
    @test recursive_sin'(1.0) == cos(1.0)
    @test recursive_sin''(1.0) == -sin(1.0)
    @test recursive_sin'''(1.0) == -cos(1.0)
    @test_broken recursive_sin''''(1.0) == sin(1.0)
    @test_broken recursive_sin'''''(1.0) == cos(1.0)
    @test_broken recursive_sin''''''(1.0) == -sin(1.0)

    # Test the special rules for sin/cos/exp (sixth derivatives).
    @test sin''''''(1.0) == -sin(1.0)
    @test cos''''''(1.0) == -cos(1.0)
    @test exp''''''(1.0) == exp(1.0)

    # Differentiating through array construction and `prod`: d/dx (4x) == 4.
    @test (x -> prod([x, 4]))'(3) == 4
end

# Basic nested (forward-over-forward) test: `sin_twice_fwd` takes a forward-mode
# second derivative internally, so differentiating it once more must give the third
# derivative of `sin`.
function sin_twice_fwd(x)
    let var"'" = Diffractor.PrimeDerivativeFwd
        sin''(x)
    end
end

let var"'" = Diffractor.PrimeDerivativeFwd
    @test sin_twice_fwd'(1.0) == sin'''(1.0)
end

@testset "No partials" begin
    # Counters recording how often the primal and the frule are invoked.
    primal_calls = Ref(0)
    function foo(x, y)
        primal_calls[] += 1
        return x + y
    end

    frule_calls = Ref(0)
    function ChainRulesCore.frule((_, ẋ, ẏ), ::typeof(foo), x, y)
        frule_calls[] += 1
        return x + y, ẋ + ẏ
    end

    # Special case if there is no derivative information at all: with every input a
    # `ZeroBundle`, the frule must be skipped and the primal evaluated exactly once.
    # `ZeroBundle` is qualified because this module only does `using Diffractor`.
    @test (Diffractor.∂☆{1}())(
        Diffractor.ZeroBundle{1}(foo),
        Diffractor.ZeroBundle{1}(2.0),
        Diffractor.ZeroBundle{1}(3.0),
    ) == Diffractor.ZeroBundle{1}(5.0)
    @test frule_calls[] == 0
    @test primal_calls[] == 1
end

@testset "indexing" begin
    # Test to make sure that `:boundscheck` and such are properly handled.
    function foo(x)
        t = (x, x)
        return t[1] + 1
    end
    let var"'" = Diffractor.PrimeDerivativeFwd
        @test foo'(1.0) == 1.0
    end

    # Test that `@inbounds` is ignored by Diffractor: indexing a 2-tuple at 3 must
    # still raise a `BoundsError` when differentiated.
    function foo_errors(x)
        t = (x, x)
        @inbounds return t[3] + 1
    end
    let var"'" = Diffractor.PrimeDerivativeFwd
        # The call itself is expected to throw, so no result comparison is needed.
        @test_throws BoundsError foo_errors'(1.0)
    end
end

@testset "structs" begin
    struct IDemo
        x::Float64
        y::Float64
    end

    # Differentiate through struct construction and field access:
    # foo(a) = 2a, so foo' == 2 and foo'' == 0.
    function foo(a)
        obj = IDemo(2.0, a)
        return obj.x * obj.y
    end
    let var"'" = Diffractor.PrimeDerivativeFwd
        @test foo'(100.0) == 2.0
        @test foo''(100.0) == 0.0
    end
end

@testset "tuples" begin
    # Differentiate through tuple construction, `first`, and indexing: foo(a) = 2a.
    function foo(a)
        tup = (2.0, a)
        return first(tup) * tup[2]
    end
    let var"'" = Diffractor.PrimeDerivativeFwd
        @test foo'(100.0) == 2.0
        @test foo''(100.0) == 0.0
    end
end

@testset "vararg" begin
    # Differentiate through splatting a tuple into `*`: foo(a) = 2a.
    function foo(a)
        tup = (2.0, a)
        return *(tup...)
    end
    let var"'" = Diffractor.PrimeDerivativeFwd
        @test foo'(100.0) == 2.0
        @test foo''(100.0) == 0.0
    end
end
end