1
- tup2 (x) = Tuple {Any,Any} (x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86
2
-
3
1
# ####
4
2
# #### Comprehension: Iterators.map
5
3
# ####
@@ -8,7 +6,7 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/Julia
8
6
9
7
function rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (collect), gen:: G ) where {G<: Base.Generator }
10
8
@debug " collect generator"
11
- ys, backs = unzip_map (x -> rrule_via_ad (cfg, gen. f, x)|> tup2 , gen. iter)
9
+ ys, backs = unzip_map (x -> rrule_via_ad (cfg, gen. f, x), gen. iter)
12
10
proj_f = ProjectTo (gen. f)
13
11
proj_iter = ProjectTo (gen. iter)
14
12
function generator_pullback (dys_raw)
@@ -28,8 +26,8 @@ ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.Pro
28
26
Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
29
27
Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
30
28
31
- Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5])
32
- Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5])
29
+ Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape
30
+ Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators
33
31
34
32
Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3)
35
33
Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally
@@ -44,11 +42,10 @@ Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3
44
42
@btime Yota.grad($(rand(1000))) do xs
45
43
sum(abs2, [sqrt(x) for x in xs])
46
44
end
47
- # Yota min 1.134 ms, mean 1.207 ms (22017 allocations, 548.50 KiB)
48
- # Diffractor min 936.708 μs, mean 1.020 ms (18028 allocations, 611.25 KiB)
49
- # without unzip_map min 734.292 μs, mean 810.341 μs (13063 allocations, 517.97 KiB)
45
+ # Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB)
46
+ # Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB)
50
47
51
- # Zygote min 6.117 μs, mean 11.287 μs (24 allocations, 40.31 KiB)
48
+ # Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB)
52
49
53
50
54
51
@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
57
54
end
58
55
sum(abs2, zs)
59
56
end
60
- # Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB)
61
- # Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB)
62
- # without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB)
57
+ # Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB)
58
+ # Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB)
63
59
64
- # Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB)
60
+ # Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster
65
61
66
62
67
63
=#
0 commit comments