-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathforward.jl
More file actions
117 lines (109 loc) · 4.45 KB
/
forward.jl
File metadata and controls
117 lines (109 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
    fwd_transform(ci, args...)

Non-mutating counterpart of `fwd_transform!`.

Copy `ci`, run `fwd_transform!` on the copy with `args...` forwarded unchanged, and
return the copy. The `ci` passed in is not handed to `fwd_transform!`.
"""
function fwd_transform(ci, args...)
    transformed = copy(ci)
    # The result of the in-place call is deliberately ignored; the copy itself is returned.
    fwd_transform!(transformed, args...)
    return transformed
end
"""
    fwd_transform!(ci, mi, nargs, N)

Rewrite the code info `ci` in place into its forward-mode form and return `ci`.

Every original statement is mapped to one or more new statements:

- `:call`, `:new` and `:splatnew` become calls through `∂☆{N}()` / `∂☆new{N}()`;
- slot and argument references are shifted by two, because two slots (`#self#` and
  `args`) are prepended to the slot tables;
- literals are wrapped with `zero_bundle{N}()`;
- `GotoNode` / `GotoIfNot` targets are renumbered to the new statement positions.

A prologue is emitted first that fills the shifted argument slots by calling `getfield`
on `SlotNumber(2)`.

# Arguments
- `ci`: the code info to rewrite; its `code`, `codelocs`, slot tables, `ssavaluetypes`,
  `ssaflags`, `method_for_inference_limit_heuristics` and `edges` are overwritten.
- `mi`: the method instance; `mi.def` must be a `Method`, `mi.sparam_vals` supplies
  values for `:static_parameter` expressions, and `mi` is stored in `ci.edges`.
- `nargs`: only used for vararg methods, where fields `meth.nargs:(nargs + 1)` of
  `SlotNumber(2)` are collected into the trailing slot (presumably the number of
  arguments at the call site; confirm against the caller).
- `N`: type parameter passed to `∂☆`, `∂☆new`, `ZeroBundle`, `DNEBundle`, `zero_bundle`
  and `∂vararg`.

# Throws
- An error if an `Expr` with an unhandled head is encountered.
"""
function fwd_transform!(ci, mi, nargs, N)
    new_code = Any[]
    new_codelocs = Any[]
    # ssa_mapping[i]: index of the new statement holding the value of original statement i.
    ssa_mapping = Int[]
    # loc_mapping[i]: index of the first new statement generated for original statement i.
    loc_mapping = Int[]

    # Non-`Expr` values can be used inline as operands.
    emit!(@nospecialize stmt) = stmt
    # Expressions with one of the heads below get a statement of their own and are
    # referred to through the returned `SSAValue`; any other `Expr` is passed through.
    function emit!(stmt::Expr)
        stmt.head ∈ (:call, :(=), :new, :isdefined) || return stmt
        push!(new_code, stmt)
        # Reuse the most recently recorded code location (0 when nothing was recorded yet).
        push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end])
        return SSAValue(length(new_code))
    end

    function mapstmt!(@nospecialize stmt)
        if isexpr(stmt, :(=))
            return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2])))
        elseif isexpr(stmt, :call)
            args = map(stmt.args) do arg
                emit!(mapstmt!(arg))
            end
            return Expr(:call, ∂☆{N}(), args...)
        elseif isexpr(stmt, :new)
            args = map(stmt.args) do arg
                emit!(mapstmt!(arg))
            end
            return Expr(:call, ∂☆new{N}(), args...)
        elseif isexpr(stmt, :splatnew)
            args = map(stmt.args) do arg
                emit!(mapstmt!(arg))
            end
            return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
        elseif isa(stmt, SSAValue)
            return SSAValue(ssa_mapping[stmt.id])
        elseif isa(stmt, Core.SlotNumber)
            # Two slots are prepended to the slot tables below, so every slot shifts by 2.
            return SlotNumber(2 + stmt.id)
        elseif isa(stmt, Argument)
            return SlotNumber(2 + stmt.n)
        elseif isa(stmt, NewvarNode)
            return NewvarNode(SlotNumber(2 + stmt.slot.id))
        elseif isa(stmt, ReturnNode)
            return ReturnNode(emit!(mapstmt!(stmt.val)))
        elseif isa(stmt, GotoNode)
            # The label is renumbered in the control-flow pass further down.
            return stmt
        elseif isa(stmt, GotoIfNot)
            # Branch on the result of `primal` applied to the mapped condition; the
            # destination is renumbered in the control-flow pass further down.
            return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest)
        elseif isexpr(stmt, :static_parameter)
            return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int])
        elseif isexpr(stmt, :foreigncall)
            return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?")
        elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) || isexpr(stmt, :loopinfo) ||
               isexpr(stmt, :code_coverage_effect)
            # Can't trust that meta annotations are still valid in the AD'd
            # version.
            return nothing
        elseif isexpr(stmt, :isdefined)
            # The `:isdefined` expression is evaluated inside the transformed code, where
            # original slot `k` lives at `2 + k`. Renumber slot operands the same way the
            # `SlotNumber` branch does, so the check queries the intended variable. Any
            # other operand kind is passed through untouched.
            operands = map(stmt.args) do arg
                isa(arg, Core.SlotNumber) ? SlotNumber(2 + arg.id) : arg
            end
            return Expr(:call, ZeroBundle{N}, emit!(Expr(:isdefined, operands...)))
        elseif isexpr(stmt, :boundscheck)
            # Always disable `@inbounds`, as we don't actually know if the AD'd
            # code is truly `@inbounds` or not.
            return DNEBundle{N}(true)
        else
            # Fallback case, for literals.
            # If it is an Expr, then it is not a literal
            if isa(stmt, Expr)
                error("Unexpected statement encountered. This is a bug in Diffractor. stmt=$stmt")
            end
            return Expr(:call, zero_bundle{N}(), stmt)
        end
    end

    # Prologue: fill the shifted argument slots from the fields of `SlotNumber(2)`.
    meth = mi.def::Method
    for i = 1:meth.nargs
        if meth.isva && i == meth.nargs
            # Trailing vararg slot: gather the remaining fields and combine them
            # with `∂vararg{N}()`.
            args = map(i:(nargs+1)) do j::Int
                emit!(Expr(:call, getfield, SlotNumber(2), j))
            end
            emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...)))
        else
            emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, getfield, SlotNumber(2), i)))
        end
    end

    for (stmt, codeloc) in zip(ci.code, ci.codelocs)
        # Record where this statement's group starts before any helper statements are
        # emitted, so that jumps land on the first of them.
        push!(loc_mapping, length(new_code)+1)
        # Push the location first: helper statements emitted by `mapstmt!` copy the last
        # recorded location and thereby inherit this statement's `codeloc`.
        push!(new_codelocs, codeloc)
        push!(new_code, mapstmt!(stmt))
        push!(ssa_mapping, length(new_code))
    end

    # Rewrite control flow
    for (i, stmt) in enumerate(new_code)
        if isa(stmt, GotoNode)
            new_code[i] = GotoNode(loc_mapping[stmt.label])
        elseif isa(stmt, GotoIfNot)
            new_code[i] = GotoIfNot(stmt.cond, loc_mapping[stmt.dest])
        end
    end

    # Prepend the two new slots that the slot shift above accounts for.
    ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
    ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
    ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...]
    ci.code = new_code
    ci.codelocs = new_codelocs
    ci.ssavaluetypes = length(new_code)
    ci.ssaflags = zeros(UInt8, length(new_code))
    ci.method_for_inference_limit_heuristics = meth
    ci.edges = MethodInstance[mi]
    return ci
end