Multi-threading with ThreadsX
As you saw in the previous section, Base.Threads does not have a built-in parallel reduction. You can implement one by hand, but all such solutions are somewhat awkward, and unless you pay close attention you can run into problems with thread safety and performance (slow atomic variables, false sharing, etc.).
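To see why a hand-written reduction is awkward, here is a minimal sketch of one with plain Base.Threads. It splits the range into contiguous chunks, reduces each chunk in its own task into a local variable, and combines the partial results at the end. The function names and the chunking strategy are illustrative assumptions. This is not how ThreadsX works internally.

```julia
using Base.Threads

# Serial kernel: sum of squares over one chunk, accumulated in a local variable.
function sum_of_squares_serial(r)
    s = 0
    for i in r
        s += i^2
    end
    return s
end

# Hypothetical hand-written parallel reduction (assumes n >= 1).
# Each task owns its chunk and its own accumulator, so there is no shared
# variable to protect with atomics or locks, and no false sharing.
function chunked_sum_of_squares(n::Int; nchunks::Int = nthreads())
    chunksize = cld(n, nchunks)                       # ceiling division
    chunks = Iterators.partition(1:n, chunksize)      # contiguous sub-ranges of 1:n
    tasks = [@spawn(sum_of_squares_serial(chunk)) for chunk in chunks]
    return sum(fetch, tasks)                          # combine the partial sums serially
end

println(chunked_sum_of_squares(10))   # prints 385
```

Even this small example needs a separate kernel, manual chunking, task creation and a final combine step. A library-level parallel `mapreduce` is meant to hide exactly this bookkeeping.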
Enter ThreadsX, a multi-threaded Julia library that provides parallel versions of some of the Base functions. To see the list of supported functions, use the double-TAB feature inside the REPL:
using ThreadsX
ThreadsX.<TAB>
?ThreadsX.mapreduce
?mapreduce
As you see in this example, not all functions in ThreadsX are well-documented, but this is exactly the point: they reproduce the functionality of their Base serial equivalents, so you can always look up help on the corresponding serial function.
Consider this Base function:
mapreduce(x->x^2, +, 1:10) # sum up squares of all integers from 1 to 10
This function allows an alternative syntax:
mapreduce(+,1:10) do i
    i^2 # plays the role of the function applied to each element
end
To parallelize either snippet, replace mapreduce with ThreadsX.mapreduce, assuming you are running Julia with multiple threads. Do not time this code: the computation is so fast that the timing would most likely be dominated by the overhead of splitting the work into tasks and scheduling them on the threads. Instead, let’s parallelize and time the slow series.
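Two details from the paragraph above are easy to check in the REPL. First, the do-block form is only syntactic sugar: Julia passes the block as an anonymous function in the first argument position. Second, `Threads.nthreads()` reports how many threads the session was started with, and a threaded version can only help if that number is greater than one. A quick sketch:

```julia
using Base.Threads

# How many threads does this session have? (set at startup, e.g. `julia -t 4`)
println("threads available: ", nthreads())

# The do-block becomes the first argument of mapreduce, so these two calls are equivalent.
a = mapreduce(x -> x^2, +, 1:10)
b = mapreduce(+, 1:10) do i
    i^2
end
println(a, " ", b)   # prints 385 385
```

With several threads, swapping `mapreduce` for `ThreadsX.mapreduce` in either call should return the same value.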
Parallelizing the slow series with ThreadsX.mapreduce
Save the following as mapreduce.jl:
using BenchmarkTools, ThreadsX
# digitsin(digitSequence, num) returns true if the decimal digits of num contain digitSequence
function digitsin(digitSequence::Int, num)
    base = 10
    while (digitSequence ÷ base > 0); base *= 10; end # base = 10^(number of digits in digitSequence)
    while num > 0; (num % base) == digitSequence && return true; num ÷= 10; end # compare trailing digits, then drop one digit
    return false
end
function slow(n::Int64, digitSequence::Int)
    total = ThreadsX.mapreduce(+,1:n) do i
        if !digitsin(digitSequence, i)
            1.0 / i # keep terms whose index does not contain digitSequence
        else
            0.0
        end
    end
    return total
end
total = @btime slow(Int64(1e8), 9)
println("total = ", total) # total = 13.277605949855294
With 4 CPU cores, I see:
$ julia mapreduce.jl # runtime with 1 thread: 469ms
$ julia -t 2 mapreduce.jl # runtime with 2 threads: 266ms
$ julia -t 4 mapreduce.jl # runtime with 4 threads: 136ms
$ julia -t 8 mapreduce.jl # what should we expect?
Using the compact (one-line) if-else notation, shorten this code by four lines. Time the new, shorter code with one and several threads.
Hint: the syntax is 1 > 2 ? "1 is greater than 2" : "1 is not greater than 2"
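The helper `digitsin()` in mapreduce.jl is compact but not obvious. It first sets `base` to 10^k, where k is the number of digits in `digitSequence`. It then compares the last k digits of `num` with `digitSequence`, dropping one trailing digit per step. Below is a quick serial sanity check. The cross-check against a string search is only an illustration and is not part of the original script.

```julia
# Copy of digitsin() from mapreduce.jl, reformatted with comments.
function digitsin(digitSequence::Int, num)
    base = 10
    while (digitSequence ÷ base > 0)   # base = 10^(number of digits in digitSequence)
        base *= 10
    end
    while num > 0
        (num % base) == digitSequence && return true  # do the trailing digits match?
        num ÷= 10                                      # drop the last digit
    end
    return false
end

println(digitsin(9, 19))      # prints true  (last digit is 9)
println(digitsin(9, 123))     # prints false
println(digitsin(12, 3125))   # prints true  (312 % 100 == 12)

# Cross-check against a string-based search on the first 100_000 integers.
agree9  = all(digitsin(9, i)  == occursin("9",  string(i)) for i in 1:100_000)
agree12 = all(digitsin(12, i) == occursin("12", string(i)) for i in 1:100_000)
println(agree9 && agree12)    # prints true
```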
Parallelizing the slow series with ThreadsX.sum
?sum
sum(x->x^2, 1:10)
ThreadsX.sum(x->x^2, 1:10)
ThreadsX.sum(x^2 for x in 1:10) # alternative syntax
The expression inside the parentheses in the last line is a generator. It produces the sequence on the fly without storing the individual elements, so it takes very little memory.
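The memory difference between a generator and the vector collected from it is easy to measure. The sketch below uses `Base.summarysize`, an unexported Base function that reports the total number of bytes reachable from an object. The size `n = 10^6` is an arbitrary choice.

```julia
n = 1_000_000
g = (i^2 for i in 1:n)   # generator: holds only the function and the range
v = collect(g)           # vector: holds all n values (8 bytes per Int on a 64-bit system)

println(sum(g) == sum(v))   # prints true: same values either way

println("generator: ", Base.summarysize(g), " bytes")
println("vector:    ", Base.summarysize(v), " bytes")
```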
(i for i in 1:10) # generator
collect(i for i in 1:10) # construct a vector from it (this one stores every element, so it takes more memory)
[i for i in 1:10] # functionally the same (a vector built with an array comprehension)
Let’s use a generator with \(10^8\) elements to compute our slow series sum:
using BenchmarkTools # assumes digitsin() from mapreduce.jl is already defined in this session
@btime sum(!digitsin(9, i) ? 1.0/i : 0 for i in 1:100_000_000)
# serial code: 553ms, prints 13.277605949858103
It is very easy to parallelize:
using BenchmarkTools, ThreadsX
@btime ThreadsX.sum(!digitsin(9, i) ? 1.0/i : 0 for i in 1:100_000_000)
# with 4 threads: 135ms, prints 13.277605949854381
The expression [i for i in 1:10 if i%2==1] produces an array of odd integers between 1 and 10. Using this syntax, remove the zero terms from the last generator, i.e. write parallel code that sums the slow series with a generator containing only non-zero terms. It should run slightly faster than the code with the original generator. (I get 527.159 ms runtime.)
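The totals printed for the slow series (13.277605949858103 serially, and 13.277605949854381 and 13.277605949855294 with threads) agree to roughly 12 significant digits but not in the last few. This is expected. Floating-point addition is not associative, and a parallel reduction groups the terms differently from a serial loop. A minimal illustration:

```julia
# Regrouping the same three additions changes the rounding.
x = (0.1 + 0.2) + 0.3
y = 0.1 + (0.2 + 0.3)
println(x, " vs ", y)   # prints 0.6000000000000001 vs 0.6

# The same effect for a harmonic series summed in two different orders.
terms = [1.0 / i for i in 1:1_000_000]
s_forward  = foldl(+, terms)            # strictly left to right, largest terms first
s_backward = foldl(+, reverse(terms))   # smallest terms first
println(s_forward - s_backward)         # usually a tiny, possibly nonzero, difference
```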
Finally, let’s rewrite our code applying a function to all integers in a range:
function numericTerm(i)
    !digitsin(9, i) ? 1.0/i : 0
end
@btime ThreadsX.sum(numericTerm, 1:Int64(1e8)) # 135ms, same result
Rewrite the last line using mapreduce instead of sum. Hint: look up help for mapreduce().
Other parallel functions
ThreadsX provides various parallel functions for sorting. Sorting is intrinsically hard to parallelize, so do not expect 100% parallel efficiency. Let’s take a look at sort() and sort!():
n = Int64(1e8)
r = rand(Float32, (n)); # random floats in [0, 1)
first(r,10) # first 10 elements, same as r[1:10]
last(r,10) # last 10 elements
?sort # the default algorithm depends on the element type and on the Julia version
@btime sort(r); # 523ms, serial sorting
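@btime sort!(r); # 95ms, in-place serial sorting (note: after the first evaluation r is already sorted)
@btime sort!(x) setup=(x = rand(Float32, n)) evals=1; # fresh unsorted array for every evaluation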
r = rand(Float32, (n));
@btime ThreadsX.sort(r); # 998ms, parallel sorting with 4 threads
@btime ThreadsX.sort!(r); # 595ms, in-place parallel sorting with 4 threads (again, r is sorted after the first evaluation)
# parallel version is actually slower in recent Julia ...
?ThreadsX.sort! # there is a good manual page
# similar comparison for integers
r = rand(Int32, (n));
@btime sort!(r); # 32ms in serial
r = rand(Int32, (n));
@btime ThreadsX.sort!(r); # 342.983 ms with 4 threads
# parallel version is actually slower in recent Julia ...
Searching for extrema is much more parallel-friendly:
n = Int64(1e9)
r = rand(Int32, (n)); # 1e9 Int32 values need 4 GB of memory: make sure you have enough
@btime maximum(r) # 67ms in serial
@btime ThreadsX.maximum(r) # 37ms with 4 threads
Finally, another useful function is ThreadsX.map(), which applies a function to every element without a reduction. We will take a closer look at it in one of the following sections.
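Unlike a reduction, map() has no combining step: every output element depends only on its own input, which is why it parallelizes easily. The following minimal sketch with plain Base.Threads makes this concrete. `threaded_map` is a hypothetical name used only for illustration, and ThreadsX.map() is meant to provide the same result without this boilerplate.

```julia
using Base.Threads

# Hypothetical helper: a multi-threaded map over a Vector.
# Each iteration writes a different slot of `out`, so no locks or atomics are needed.
function threaded_map(f, xs::Vector)
    out = Vector{typeof(f(first(xs)))}(undef, length(xs))  # assumes non-empty xs and a type-stable f
    @threads for k in eachindex(xs)
        out[k] = f(xs[k])
    end
    return out
end

xs = collect(1:1_000)
squares = threaded_map(x -> x^2, xs)
println(squares == map(x -> x^2, xs))   # prints true
```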
To sum up this section, ThreadsX.jl provides a super easy way to parallelize some of the Base library functions. It includes multi-threaded reduction, which scaled very well in the reduction examples above; parallel sorting was the exception in our tests. To list the supported functions, type ThreadsX. and press TAB twice, and don’t forget to use the built-in help pages.