Skip to content

Commit fe779b3

Browse files
committed
day two
1 parent 328432f commit fe779b3

File tree

8 files changed

+193
-61
lines changed

8 files changed

+193
-61
lines changed

src/rulesets/Base/base.jl

+9-21
Original file line numberDiff line numberDiff line change
@@ -249,34 +249,22 @@ end
249249
#####
250250

251251
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F}
252-
y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # could be broadcast, but Yota likes this one
253-
return Broadcast.materialize(y), back
252+
# y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # Yota likes this one
253+
# return Broadcast.materialize(y), back
254+
y, back = rrule_via_ad(cfg, broadcast, f, x) # but testing like this one
255+
return y, back
254256
end
255257

256-
# Could accept Any?
257-
# `_unmap_pad` is also used for `zip`
258258
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F}
259+
@debug "rrule(map, f, arrays...)" f
259260
z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...)
260-
# z, backs = unzip(map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...))
261-
function map_pullback(dz)
262-
df, dxy... = unzip_map(|>, unthunk(dz), backs)
263-
# df, dxy... = unzip(map(|>, unthunk(dz), backs))
264-
return (NoTangent(), ProjectTo(sum(df)), map(_unmap_pad, (x, ys...), dxy)...)
261+
function map_pullback_2(dz)
262+
df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs)
263+
return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...)
265264
end
266-
z, map_pullback
265+
z, map_pullback_2
267266
end
268267

269-
# function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F}
270-
# z, zip_back = rrule(zip, x, ys...)
271-
# m, map_back = rrule(config, map, Splat(f), z) # maybe this is inefficient?
272-
# function map_pullback(dm)
273-
# _, dsplatf, dz = map_back(dm)
274-
# _, dxys... = zip_back(dz)
275-
# return (NoTangent(), 0, dxys...)
276-
# end
277-
# return m, map_back
278-
# end
279-
280268
#####
281269
##### `task_local_storage`
282270
#####

src/rulesets/Base/broadcast.jl

+22-3
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,35 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F}
119119
end
120120

121121
# Path 4: The most generic, save all the pullbacks. Can be 1000x slower.
122-
# Since broadcast makes no guarantee about order of calls, and un-fusing
123-
# can change the number of calls, don't bother to try to reverse the iteration.
122+
# While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration.
123+
124+
#=
125+
126+
julia> Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), [1,2,3.0])
127+
┌ Debug: split broadcasting generic
128+
│ f = #69 (generic function with 1 method)
129+
│ N = 1
130+
└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:126
131+
(14.0, (ZeroTangent(), [2.0, 4.0, 6.0]))
132+
133+
julia> ENV["JULIA_DEBUG"] = nothing
134+
135+
julia> @btime Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), $(rand(1000)));
136+
min 1.321 ms, mean 1.434 ms (23010 allocations, 594.66 KiB) # with unzip_map, as before
137+
min 1.279 ms, mean 1.393 ms (23029 allocations, 595.73 KiB) # with unzip_map_reversed
138+
139+
julia> @btime Yota.grad(xs -> sum(abs2, abs.(xs)), $(randn(1000))); # Debug: split broadcasting derivative
140+
min 2.144 μs, mean 6.620 μs (6 allocations, 23.88 KiB)
141+
142+
=#
124143

125144
function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
126145
@debug("split broadcasting generic", f, N)
127146
ys3, backs = unzip_broadcast(args...) do a...
128147
rrule_via_ad(cfg, f, a...)
129148
end
130149
function back_generic(dys)
131-
deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match)
150+
deltas = unzip_map_reversed(backs, unthunk(dys)) do back, dy
132151
map(unthunk, back(dy))
133152
end
134153
dargs = map(unbroadcast, args, Base.tail(deltas))

src/rulesets/Base/iterators.jl

+14-27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor
1+
# temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86
# Rebuild `x` through the `Tuple{Any,Any}` constructor, so that whatever
# `rrule_via_ad` hands back is a plain 2-tuple `(value, pullback)`.
function tup2(x)
    return Tuple{Any,Any}(x)
end
22

33
#####
44
##### Comprehension: Iterators.map
@@ -7,38 +7,18 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor
77
# Comprehension does guarantee iteration order. Thus its gradient must reverse.
88

99
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator}
10-
# ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)
11-
ys, backs = unzip(map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter))
10+
@debug "collect generator"
11+
ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)
1212
proj_f = ProjectTo(gen.f)
1313
proj_iter = ProjectTo(gen.iter)
1414
function generator_pullback(dys_raw)
1515
dys = unthunk(dys_raw)
16-
# dfs, dxs = unzip_map(|>, Iterators.reverse(dys), Iterators.reverse(backs))
17-
dfs, dxs = unzip(map(|>, Iterators.reverse(dys), Iterators.reverse(backs)))
18-
return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(reverse!!(dxs))))
16+
dfs, dxs = unzip_map_reversed(|>, dys, backs)
17+
return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(dxs)))
1918
end
2019
ys, generator_pullback
2120
end
2221

23-
"""
24-
reverse!!(x)
25-
26-
Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`.
27-
Only safe if you are quite sure nothing else closes over `x`.
28-
"""
29-
function reverse!!(x::AbstractArray)
30-
if ChainRulesCore.is_inplaceable_destination(x)
31-
Base.reverse!(x)
32-
else
33-
Base.reverse(x)
34-
end
35-
end
36-
frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot)
37-
function rrule(::typeof(reverse!!), x::AbstractArray)
38-
reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy)))
39-
return reverse!!(x), reverse!!_back
40-
end
41-
4222
# Needed for Yota, but shouldn't these be automatic?
4323
# Reverse rule for the `Base.Generator(f, iter)` constructor.
# The pullback reads the `f` and `iter` fields of the incoming tangent `dy` and
# returns them as the cotangents of the two constructor arguments.
function ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter)
    gen = Base.Generator(f, iter)
    pullback = dy -> (NoTangent(), dy.f, dy.iter)
    return gen, pullback
end
4424
# Reverse rule for the `Iterators.ProductIterator(iters)` constructor.
# The pullback forwards the `iterators` field of the incoming tangent `dy`
# as the cotangent of `iters`.
function ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters)
    prod_iter = Iterators.ProductIterator(iters)
    pullback = dy -> (NoTangent(), dy.iterators)
    return prod_iter, pullback
end
@@ -107,18 +87,25 @@ function rrule(::typeof(zip), xs::AbstractArray...)
10787
end
10888

10989
# Array of `Tangent{T,B}` with tuple backing `B`: reinterpret the elements as their
# backing tuples `B`, then `unzip` that array of tuples.
function _tangent_unzip(xs::AbstractArray{Tangent{T,B}}) where {T<:Tuple, B<:Tuple}
    backing = reinterpret(B, xs)
    return unzip(backing)
end
110-
_tangent_unzip(xs::AbstractArray) = unzip(xs) # Diffractor
90+
# Fallback for any other array: `unzip` it directly.  # temp fix for Diffractor
function _tangent_unzip(xs::AbstractArray)
    return unzip(xs)
end
11191

92+
# This is like unbroadcast, except for map's stopping-short behaviour, not broadcast's extension.
93+
# Closing over `x` lets us re-use ∇getindex.
11294
# Map a cotangent `dx`, computed over the elements that `map`/`zip` actually visited,
# back onto the full array `x`.
#
# * Equal lengths: reshape `dx` to the axes of `x` and project onto `x`.
# * Otherwise: pass the flattened `dx` to `∇getindex` for the leading `length(dx)`
#   linear indices of `x`. Presumably the remaining entries receive zero gradient,
#   as in the commented-out `vcat` alternative kept below -- TODO confirm.
function _unmap_pad(x::AbstractArray, dx::AbstractArray)
    if length(x) == length(dx)
        return ProjectTo(x)(reshape(dx, axes(x)))
    else
        # Log both lengths: inside this branch `length(x) == length(dx)` is always
        # `false`, so logging the comparison itself carried no information.
        @debug "_unmap_pad is extending gradient" length(x) length(dx)
        i1 = firstindex(x)
        # dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx)))
        # ProjectTo(x)(reshape(dx2, axes(x)))
        return ∇getindex(x, vec(dx), i1:i1+length(dx)-1)
    end
end
122105

123-
106+
# For testing
107+
# Test helper: reverse rule for `collect ∘ zip`.
# Re-uses the `zip` rule's pullback unchanged and materialises the zipped primal
# with `collect`.
function rrule(::ComposedFunction{typeof(collect), typeof(zip)}, xs::AbstractArray...)
    zipped, zip_pullback = rrule(zip, xs...)
    collected = collect(zipped)
    return collected, zip_pullback
end
124111

src/unzipped.jl

+59
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,65 @@ function unzip_map(f::F, args...) where {F}
8585
return StructArrays.components(StructArray(Iterators.map(f, args...)))
8686
end
8787

88+
# All-tuple arguments: run the ordinary eager `map`, then `unzip` its result,
# instead of going through the `StructArray`-based generic method.
function unzip_map(f::F, args::Tuple...) where {F}
    mapped = map(f, args...)
    return unzip(mapped)
end
89+
90+
# All-GPU-array arguments: run the ordinary eager `map`, then `unzip` its result,
# instead of going through the `StructArray`-based generic method.
function unzip_map(f::F, args::AbstractGPUArray...) where {F}
    mapped = map(f, args...)
    return unzip(mapped)
end
91+
92+
# Lazily reversed iterator over the first `len` elements (linear order) of `a`.
# Arrays are sliced with a view; tuples have no `view` method, so they are indexed.
_reversed_prefix(a, len::Integer) = Iterators.reverse(@view a[begin:begin+len-1])
_reversed_prefix(a::Tuple, len::Integer) = Iterators.reverse(a[begin:begin+len-1])

"""
    unzip_map_reversed(f, args...)

Like `unzip_map(f, args...)`, but calls `f` on the elements in reverse order
(last element first), then reverses each output so the results line up with
the original order of `args`.

Like `map`, stops at the shortest argument: when lengths differ, only the first
`minimum(length, args)` elements of each argument are used.

Throws an `ArgumentError` if the inferred return type of `f` is concrete and
not a `Tuple`.
"""
function unzip_map_reversed(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    if isconcretetype(T)
        T <: Tuple || throw(ArgumentError("""unzip_map_reversed(f, args) only works on functions returning a tuple,
            but f = $(sprint(show, f)) returns type T = $T"""))
    end
    len1 = length(first(args))
    # `outs` is assigned separately in each branch (rather than hoisting the common
    # call) so that each call site sees one concrete `rev_args` type.
    if all(a -> length(a) == len1, args)
        rev_args = map(Iterators.reverse, args)
        outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...)))
    else
        len = minimum(length, args)
        # Truncate every argument to the common length before reversing, so that the
        # reversed iterators stay aligned element-for-element.
        rev_args = map(a -> _reversed_prefix(a, len), args)
        outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...)))
    end
    # The component arrays were created just above, so nothing else holds them and
    # they can be handed to `reverse!!`.
    return map(reverse!!, outs)
end
109+
110+
# All-tuple method: truncate every tuple to the shortest length, apply `f` from the
# last element to the first, then flip each unzipped output back to forward order.
# (An `ntuple(i -> a[i], Val(len))` truncation was tried; it does not infer better.)
function unzip_map_reversed(f::F, args::Tuple...) where {F}
    n = minimum(length, args)
    flipped = map(t -> reverse(t[1:n]), args)
    results = unzip(map(f, flipped...))
    return map(reverse, results)
end
119+
# function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N}
120+
# rev_args = map(reverse, args)
121+
# return map(reverse, unzip(map(f, rev_args...)))
122+
# end
123+
124+
"""
125+
reverse!!(x)
126+
127+
Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`.
128+
Only safe if you are quite sure nothing else closes over `x`.
129+
"""
130+
function reverse!!(x::AbstractArray)
    # Mutate only when ChainRulesCore says `x` may be written to;
    # otherwise return a reversed copy and leave `x` untouched.
    return ChainRulesCore.is_inplaceable_destination(x) ? Base.reverse!(x) : Base.reverse(x)
end
137+
# Arrays whose element type is an `AbstractZero` are returned as-is, with no
# reversal performed (presumably because their order carries no information).
function reverse!!(x::AbstractArray{<:AbstractZero})
    return x
end
138+
139+
# Forward rule: apply `reverse!!` to the primal and to its tangent `xdot`.
function frule((_, xdot), ::typeof(reverse!!), x::AbstractArray)
    return reverse!!(x), reverse!!(xdot)
end
140+
141+
# Reverse rule: the pullback un-thunks the cotangent and reverses it
# (out-of-place), mirroring the reversal applied to the primal.
function rrule(::typeof(reverse!!), x::AbstractArray)
    function reverse!!_back(dy)
        return (NoTangent(), reverse(unthunk(dy)))
    end
    return reverse!!(x), reverse!!_back
end
145+
146+
88147
#####
89148
##### unzip
90149
#####

test/rulesets/Base/base.jl

+34
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,38 @@
229229
test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false)
230230
end
231231
end
232+
233+
@testset "map(f, ::Array)" begin
234+
test_rrule(map, identity, [1.0, 2.0], check_inferred=false)
235+
test_rrule(map, conj, [1, 2+im, 3.0]', check_inferred=false)
236+
test_rrule(map, make_two_vec, [4.0, 5.0 + 6im], check_inferred=false)
237+
# @test rrule(CFG, map, make_two_vec, [4.0, 5.0 + 6im])[2]([1:2, 3:4])[3] ≈ [1 + 2im, 3 + 4im] # FiniteDifferences DimensionMismatch
238+
239+
@test_skip test_rrule(map, Multiplier(rand() + im), rand(3), check_inferred=false)
240+
rrule(CFG, map, Multiplier(2.0), [3, 4, 5.0])[2]([10, 20, 30]) # (NoTangent(), Multiplier{Float64}(259.99999), [19.99999, 40.000, 60.000]) -- WTF?
241+
@test_skip test_rrule(map, Multiplier(rand() + im) ⊢ NoTangent(), rand(3), check_inferred=false) # Expression: ad_cotangent isa NoTangent Evaluated: Multiplier{ComplexF64}(-3.7869064372333963 + 2.046139872866103im) isa NoTangent
242+
243+
y1, bk1 = rrule(CFG, map, abs2, [1.0, 2.0, 3.0])
244+
@test y1 == [1, 4, 9]
245+
@test bk1([4, 5, 6.0])[3] ≈ 2 .* (1:3) .* (4:6)
246+
247+
y2, bk2 = rrule(CFG, map, Counter(), [11, 12, 13.0])
248+
@test y2 == map(Counter(), 11:13)
249+
@test_skip bk2(ones(3))[3] == [93, 83, 73] # FiniteDifferences has incremented the counter very high
250+
end
251+
252+
@testset "map(f, ::Array, ::Array)" begin
253+
test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent}
254+
test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false)
255+
test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false)
256+
257+
test_rrule(map, Multiplier(rand()), rand(3), rand(4), check_inferred=false)
258+
259+
cnt3 = Counter()
260+
y3, bk3 = rrule(CFG, map, cnt3, [1, 2, 3.0], [0, -1, -2, -33.3])
261+
@test y3 == 1:3
262+
@test cnt3 == Counter(3)
263+
z3 = bk3([1, 1, 1000])
264+
@test z3[3] == [53, 33, 13000]
265+
end
232266
end

test/rulesets/Base/iterators.jl

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
2+
@testset "Comprehension" begin
3+
@testset "simple" begin
4+
y1, bk1 = rrule(CFG, collect, (i^2 for i in [1.0, 2.0, 3.0]))
5+
@test y1 == [1,4,9]
6+
t1 = bk1(4:6)[2]
7+
@test t1 isa Tangent{<:Base.Generator}
8+
@test t1.f == NoTangent()
9+
@test t1.iter ≈ 2 .* (1:3) .* (4:6)
10+
11+
y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0]))
12+
@test y2 == map(Counter(), 11:13)
13+
@test bk2(ones(3))[2].iter == [93, 83, 73]
14+
end
15+
end
16+
17+
@testset "Iterators" begin
18+
@testset "zip" begin
19+
test_rrule(collect∘zip, rand(3), rand(3))
20+
test_rrule(collect∘zip, rand(2,2), rand(2,2), rand(2,2))
21+
test_rrule(collect∘zip, rand(4), rand(2,2))
22+
23+
test_rrule(collect∘zip, rand(3), rand(5))
24+
test_rrule(collect∘zip, rand(3,2), rand(5))
25+
end
26+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ end
6161
include_test("rulesets/Base/mapreduce.jl")
6262
include_test("rulesets/Base/sort.jl")
6363
include_test("rulesets/Base/broadcast.jl")
64+
include_test("rulesets/Base/iterators.jl")
6465

6566
include_test("unzipped.jl") # used primarily for broadcast
6667

test/unzipped.jl

+28-10
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11

2-
using ChainRules: unzip_broadcast, unzip #, unzip_map
2+
using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed
33

44
@testset "unzipped.jl" begin
5-
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map,
5+
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map, unzip_map_reversed]
66
@test_throws Exception fun(sqrt, 1:3)
77

8-
@test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6])
9-
@test fun(tuple, [1, 10, 100]) == ([1, 10, 100],)
10-
@test fun(tuple, 1:3, fill(nothing, 3)) == (1:3, fill(nothing, 3))
11-
@test fun(tuple, [1, 10, 100], fill(nothing, 3)) == ([1, 10, 100], fill(nothing, 3))
12-
@test fun(tuple, fill(nothing, 3), fill(nothing, 3)) == (fill(nothing, 3), fill(nothing, 3))
8+
@test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6])
9+
@test @inferred(fun(tuple, [1, 10, 100])) == ([1, 10, 100],)
10+
@test @inferred(fun(tuple, 1:3, fill(nothing, 3))) == (1:3, fill(nothing, 3))
11+
@test @inferred(fun(tuple, [1, 10, 100], fill(nothing, 3))) == ([1, 10, 100], fill(nothing, 3))
12+
@test @inferred(fun(tuple, fill(nothing, 3), fill(nothing, 3))) == (fill(nothing, 3), fill(nothing, 3))
1313

1414
if contains(string(fun), "map")
15-
@test fun(tuple, 1:3, 4:999) == ([1, 2, 3], [4, 5, 6])
15+
@test @inferred(fun(tuple, 1:3, 4:999)) == ([1, 2, 3], [4, 5, 6])
1616
else
17-
@test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
17+
@test @inferred(fun(tuple, [1,2,3], [4 5])) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
18+
@test @inferred(fun(tuple, [1,2,3], 6)) == ([1, 2, 3], [6, 6, 6])
1819
end
1920

2021
if contains(string(fun), "map")
@@ -24,7 +25,24 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
2425
@test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7))
2526
@test fun(tuple, (1,2,3), 8) == ((1, 2, 3), (8, 8, 8))
2627
end
27-
@test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
28+
@test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
29+
end
30+
31+
@testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed]
32+
check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...))
33+
@test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector
34+
@test check(tuple, [1 2; 3 4], [5,6,7])
35+
@test check(tuple, [1 2; 3 4], [5,6,7,8,9,10])
36+
end
37+
38+
@testset "unzip_map_reversed" begin
39+
cnt(x, y) = (x, y) .+ (CNT[] += 1)
40+
CNT = Ref(0)
41+
@test unzip_map_reversed(cnt, [10, 20], [30, 40, 50]) == ([12, 21], [32, 41])
42+
@test CNT[] == 2
43+
44+
CNT = Ref(0)
45+
@test unzip_map_reversed(cnt, (10, 20, 99), (30, 40)) == ((12, 21), (32, 41))
2846
end
2947

3048
@testset "rrules" begin

0 commit comments

Comments
 (0)