-
Notifications
You must be signed in to change notification settings - Fork 66
Expand file tree
/
Copy pathdifferentials.jl
More file actions
118 lines (104 loc) · 3.8 KB
/
differentials.jl
File metadata and controls
118 lines (104 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Tests for the differential types (`Wirtinger`, `Zero`, `One`, `Thunk`), for the absence
# of method ambiguities in `+` and `*`, and for `refine_differential`.
@testset "Differentials" begin
    @testset "Wirtinger" begin
        w = Wirtinger(1+1im, 2+2im)
        @test wirtinger_primal(w) == 1+1im
        @test wirtinger_conjugate(w) == 2+2im
        @test w + w == Wirtinger(2+2im, 4+4im)
        # Adding `One()`, a plain `1`, or a thunk that produces `1` must all agree.
        @test w + One() == w + 1 == w + Thunk(()->1) == Wirtinger(2+1im, 2+2im)
        @test w * One() == One() * w == w
        @test w * 2 == 2 * w == Wirtinger(2 + 2im, 4 + 4im)
        # TODO: other + methods stack overflow
        # NOTE(review): the TODO above mentions `+`, but the test below exercises `*`
        # between two Wirtingers; confirm which operation the TODO refers to.
        @test_throws ErrorException w*w
        @test_throws ArgumentError extern(w)
        # Every element produced by iterating `w` is `w` itself.
        for x in w
            @test x === w
        end
        @test broadcastable(w) == w
        @test_throws MethodError conj(w)
    end

    @testset "Zero" begin
        z = Zero()
        @test extern(z) === false
        # `Zero()` is the additive identity and absorbs under multiplication.
        @test z + z == z
        @test z + 1 == 1
        @test 1 + z == 1
        @test z * z == z
        @test z * 1 == z
        @test 1 * z == z
        for x in z
            @test x === z
        end
        # Broadcasting treats `Zero()` as a scalar (wrapped in a `Ref`).
        @test broadcastable(z) isa Ref{Zero}
        @test conj(z) == z
    end

    @testset "One" begin
        o = One()
        @test extern(o) === true
        # `One()` adds like the number 1 and is the multiplicative identity.
        @test o + o == 2
        @test o + 1 == 2
        @test 1 + o == 2
        @test o * o == o
        @test o * 1 == 1
        @test 1 * o == 1
        for x in o
            @test x === o
        end
        # Broadcasting treats `One()` as a scalar (wrapped in a `Ref`).
        @test broadcastable(o) isa Ref{One}
        @test conj(o) == o
    end

    @testset "Thunk" begin
        @test @thunk(3) isa Thunk
        @testset "show" begin
            rep = repr(Thunk(rand))
            @test occursin(r"Thunk\(.*rand.*\)", rep)
        end
        @testset "Externing" begin
            # `extern` unwraps nested thunks all the way down.
            @test extern(@thunk(3)) == 3
            @test extern(@thunk(@thunk(3))) == 3
        end
        @testset "unthunk" begin
            # `unthunk` removes only one layer of thunking.
            @test unthunk(@thunk(3)) == 3
            @test unthunk(@thunk(@thunk(3))) isa Thunk
        end
        @testset "calling thunks should call inner function" begin
            @test (@thunk(3))() == 3
            @test (@thunk(@thunk(3)))() isa Thunk
        end
        @testset "erroring thunks should include the source in the backtrack" begin
            # Set when the expected error is caught; without it, this testset would pass
            # vacuously (running no assertions) if `extern` stopped throwing.
            errored = false
            # Must stay exactly two lines above the `@thunk(error())` call below.
            expected_line = (@__LINE__) + 2  # for testing it is at the right place
            try
                x = @thunk(error())
                extern(x)
            catch err
                err isa ErrorException || rethrow()
                errored = true
                st = stacktrace(catch_backtrace())
                # Frame 1 is the `error` function itself; frame 2 should be the thunk
                # body, i.e. the line where `@thunk` was written.
                stackframe = st[2]
                @test stackframe.line == expected_line
                @test stackframe.file == Symbol(@__FILE__)
            end
            @test errored
        end
    end

    @testset "No ambiguities in $f" for f in (+, *)
        # We don't use `Test.detect_ambiguities` because we are only interested in the
        # `+` and `*` operations. Scanning every method of these functions also reports
        # ambiguities unrelated to this package; that is acceptable, since no such
        # ambiguities occur in our dependencies.
        ambig_methods = [
            (m1, m2) for m1 in methods(f), m2 in methods(f) if Base.isambiguous(m1, m2)
        ]
        @test isempty(ambig_methods)
    end

    @testset "Refine Differential" begin
        # Complex domains (scalar or array) keep the `Wirtinger` as is ...
        @test refine_differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2)
        @test refine_differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2)
        # ... while real domains (scalar or array) collapse it to a plain value.
        @test refine_differential(typeof(1.2), Wirtinger(2,2)) == 4
        @test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4
        # For most differentials, in most domains, this does nothing
        for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0)
            for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2]))
                @test refine_differential(𝒟, der) === der
            end
        end
    end
end