Skip to content

Commit 45e2ae7

Browse files
committed
rm tup2, update times
1 parent fe779b3 commit 45e2ae7

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

src/rulesets/Base/base.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ end
257257

258258
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F}
259259
@debug "rrule(map, f, arrays...)" f
260-
z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...)
260+
z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...), x, ys...)
261261
function map_pullback_2(dz)
262262
df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs)
263263
return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...)

src/rulesets/Base/iterators.jl

+9-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86
2-
31
#####
42
##### Comprehension: Iterators.map
53
#####
@@ -8,7 +6,7 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/Julia
86

97
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator}
108
@debug "collect generator"
11-
ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)
9+
ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x), gen.iter)
1210
proj_f = ProjectTo(gen.f)
1311
proj_iter = ProjectTo(gen.iter)
1412
function generator_pullback(dys_raw)
@@ -28,8 +26,8 @@ ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.Pro
2826
Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
2927
Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
3028
31-
Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5])
32-
Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5])
29+
Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape
30+
Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators
3331
3432
Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3)
3533
Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally
@@ -44,11 +42,10 @@ Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3
4442
@btime Yota.grad($(rand(1000))) do xs
4543
sum(abs2, [sqrt(x) for x in xs])
4644
end
47-
# Yota min 1.134 ms, mean 1.207 ms (22017 allocations, 548.50 KiB)
48-
# Diffractor min 936.708 μs, mean 1.020 ms (18028 allocations, 611.25 KiB)
49-
# without unzip_map min 734.292 μs, mean 810.341 μs (13063 allocations, 517.97 KiB)
45+
# Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB)
46+
# Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB)
5047
51-
# Zygote min 6.117 μs, mean 11.287 μs (24 allocations, 40.31 KiB)
48+
# Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB)
5249
5350
5451
@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
@@ -57,11 +54,10 @@ end
5754
end
5855
sum(abs2, zs)
5956
end
60-
# Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB)
61-
# Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB)
62-
# without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB)
57+
# Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB)
58+
# Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB)
6359
64-
# Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB)
60+
# Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster
6561
6662
6763
=#

0 commit comments

Comments
 (0)