Distributed task queues for machine learning in Python – Celery, RabbitMQ, Redis
Lately I've been evaluating a couple of different distributed task queues for Python. The goal was to find a way to distribute model simulations across 200 to 1000 machines to speed up parameter estimation.
The first candidate that comes to mind in the Python ecosystem is Celery. Celery is a widely used distributed task queue and supports a number of broker backends, including, but not limited to, RabbitMQ and Redis. I therefore also wanted to compare Celery against raw Redis and raw RabbitMQ task queue implementations.
The benchmark task
This benchmark was really about messaging performance. During a real run, most of the time would be spent simulating the model, but I wanted to know how much overhead the message queue adds. So I came up with this rather simple microbenchmark:
- The master node emits 100000 tasks. Each task is just a tuple of two integers.
- The workers' job is to add the two numbers together and return the sum to the master as the result.
- I had 400 cores available. This should be enough to evaluate the networking overhead. In a real application, I was planning to use between 200 and 1000 worker machines.
- The broker ran on a separate machine and mediated between the master and the worker pool.
- I repeated every experiment 3 times. However, the variance turned out to be negligible, so I won't report it below.
Benchmark results
RabbitMQ
Let's start with RabbitMQ. I spawned 400 workers, each using a single core, and submitted the tasks on the master. Submitting all tasks took only 9 seconds. The overall round trip then took 93 seconds. So within a minute and a half, all tasks were completed and the results returned to the master.
Redis
Now the same thing for Redis. Submission was a little slower at 19 seconds, but the whole round trip completed in 73 seconds. So Redis has a slight edge here.
Celery with RabbitMQ broker
Next, I went on to check how Celery compares to raw Redis and raw RabbitMQ. How much overhead would be involved? As already mentioned, Celery supports a RabbitMQ broker and a Redis broker, and even allows combining the two: RabbitMQ can be used for scheduling and Redis for the results. In the following benchmarks, however, I used the same backend for both messaging and results.
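For illustration, such a combined setup might look like the following sketch. This is not one of the configurations I benchmarked; the hostname broker is a placeholder, as in the appendix code.

# Sketch only: RabbitMQ schedules the tasks, Redis stores the results.
from celery import Celery

app = Celery('tasks',
             broker='amqp://guest@broker//',  # RabbitMQ for messaging
             backend='redis://broker')        # Redis for the results

@app.task
def add(x, y):
    return x + y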
In a first attempt, I tried to spawn 400 Celery workers with 1 core each. Sadly, this did not work: I got timeouts instead, and reliable communication was not possible. However, the worker pool size per node can be adjusted in Celery, so I tried to spawn 50 workers with 8 cores each. This time it worked. Compared to the raw RabbitMQ solution, Celery was much slower. Job submission alone took 222 seconds; the round trip was completed in 254 seconds.
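For reference, the pool size per node is controlled by the worker's --concurrency option. With the task module from the appendix, the failing setup would correspond to celery -A celery_task worker --concurrency=1 on each of 400 nodes, and the working one to celery -A celery_task worker --concurrency=8 on each of 50 nodes.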
But note that Celery does a little more than just pushing jobs to a queue and fetching the results. It includes heartbeat signals and message acknowledgments, and it can automatically restart jobs after hardware failure on the workers, and even after hardware failure on the broker if RabbitMQ is configured accordingly (the appendix code enables two of these safeguards via task_acks_late and task_reject_on_worker_lost). So while it is slower, it is also doing more.
Celery with Redis broker
Again, I tried to start 400 workers with one core each, and again this approach failed with Celery. The Python Redis client (used by Celery) on the master threw an exception complaining about too many connections to the Redis server. So, as before, I spawned 50 workers with 8 cores each instead.
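I did not investigate this further, but the relevant limits are configurable on both sides: the Redis server caps simultaneous clients with its maxclients directive, and Celery exposes settings that bound how many connections it opens. A sketch of the Celery side, added to the appendix's celery_task.py; the values are illustrative assumptions, not what I benchmarked with:

# Illustrative values only -- assumptions, not my benchmark settings.
app.conf.broker_pool_limit = 10        # max broker connections per app
app.conf.redis_max_connections = 500   # cap on the Redis result-backend pool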
With 50 workers at 8 cores each, it worked. The jobs were submitted in 84 seconds, and the whole round trip completed in 122 seconds. Surprisingly, the Redis broker was faster than the RabbitMQ broker. I was surprised because Celery was historically developed primarily for the RabbitMQ broker. But Redis seems to be a good choice as well. Again, though, there is a trade-off: Redis offers optional persistence (snapshots and an append-only file), but it does not have the same level of resilience to hardware failure as the RabbitMQ backend. If you never want to lose a message, RabbitMQ might be the better choice.
What is next?
This comparison is not exhaustive; for example, I did not evaluate the Python RQ queue. I also evaluated and benchmarked ZeroMQ, which was blazingly fast, outperforming all the other solutions here by an order of magnitude. However, I did not include it here, since the most natural architecture for ZeroMQ would be a brokerless one, so the comparison would not really be fair.
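To make "brokerless" concrete, here is a minimal PUSH/PULL sketch with pyzmq. This is not the code I benchmarked; the ports and the master hostname are made up, and the point is only that the workers talk to the master directly, with no broker machine in between.

# Minimal brokerless pipeline sketch: the master binds, workers connect.
# Run with "master" or "worker" as the first command line argument.
import sys
import zmq

ctx = zmq.Context()

if sys.argv[1] == "master":
    tasks = ctx.socket(zmq.PUSH)     # fan tasks out to the workers
    tasks.bind("tcp://*:5557")
    results = ctx.socket(zmq.PULL)   # collect results from the workers
    results.bind("tcp://*:5558")
    N = 100
    for k in range(N):
        tasks.send_json((1, k))
    for _ in range(N):
        print(results.recv_json())
else:
    pull = ctx.socket(zmq.PULL)
    pull.connect("tcp://master:5557")  # 'master' is a placeholder hostname
    push = ctx.socket(zmq.PUSH)
    push.connect("tcp://master:5558")
    while True:
        a, b = pull.recv_json()        # JSON turns the tuple into a list
        push.send_json(a + b)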
And which queue to use now? I'd go with Celery for the moment. It comes with many nice monitoring tools, is really easy to use, and is much more high level than the other solutions. With Celery and the Redis broker, on the order of 1000 tasks per second are feasible (100000 tasks in 122 seconds is roughly 800 per second). For my application this is good enough: looking at my model runtimes, I would expect no more than 40 evaluations per second on 400 cores.
Appendix - The benchmark code
RabbitMQ
Here is the worker code:
import pika

# Connect to the broker and declare the task and result queues.
connection = pika.BlockingConnection(
    pika.ConnectionParameters(host='broker'))
channel = connection.channel()
channel.queue_declare(queue='task_queue', durable=True)
channelr = connection.channel()
channelr.queue_declare(queue='result_queue', durable=True)

def callback(ch, method, properties, body):
    # The body is the string repr of a tuple, e.g. "(1, 42)".
    print(body)
    a, b = list(map(int, body.decode()[1:-1].split(",")))
    res = str(a + b)
    channelr.basic_publish(exchange='',
                           routing_key='result_queue',
                           body=res,
                           properties=pika.BasicProperties(
                               delivery_mode=2,  # make message persistent
                           ))
    ch.basic_ack(delivery_tag=method.delivery_tag)

channel.basic_qos(prefetch_count=1)
channel.basic_consume(callback, queue='task_queue')
channel.start_consuming()
and the master code:
import pika
import time

connection = pika.BlockingConnection(pika.ConnectionParameters('broker'))
channel = connection.channel()
channel.queue_declare(queue='task_queue', durable=True)
channel.basic_qos(prefetch_count=1)

# Submit N tasks, each encoded as the string repr of a tuple.
N = 100000
start = time.time()
for k in range(N):
    channel.basic_publish(exchange='',
                          routing_key='task_queue',
                          body=str((1, k)),
                          properties=pika.BasicProperties(
                              delivery_mode=2,  # make message persistent
                          ))
send_finish = time.time()

# Consume the results and stop the clock once all N have arrived.
channelr = connection.channel()
channelr.queue_declare(queue='result_queue', durable=True)

k = 0

def callback(ch, method, properties, body):
    global k
    ch.basic_ack(delivery_tag=method.delivery_tag)
    if k == N - 1:
        end = time.time()
        print("rabbit", send_finish - start, end - start)
    k += 1

channelr.basic_qos(prefetch_count=1)
channelr.basic_consume(callback, queue='result_queue')
channelr.start_consuming()
Redis
Worker:
import redis

r = redis.Redis(host="broker", decode_responses=True)

while True:
    # blpop returns a (queue_name, value) pair; the value is the
    # string repr of a tuple, e.g. "(1, 42)".
    res = r.blpop("tasks")[1]
    a, b = list(map(int, res[1:-1].split(",")))
    r.lpush("results", a + b)
Master:
import redis
import time

r = redis.Redis(host="broker", decode_responses=True)

# Push N tasks, then block on the results list until all N are back.
start = time.time()
N = 100000
for k in range(N):
    r.lpush("tasks", (1, k))
push_end = time.time()

results = []
for k in range(N):
    results.append(r.blpop("results"))
end = time.time()
print("redis", push_end - start, end - start)
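Submission on the Redis side could presumably be sped up further by batching the LPUSH calls with a redis-py pipeline. A sketch of that idea, which I did not benchmark; the batch size of 1000 is an arbitrary assumption:

# Sketch (not benchmarked): batch task submission in a pipeline
# to avoid one network round trip per LPUSH.
import redis

r = redis.Redis(host="broker", decode_responses=True)
pipe = r.pipeline()
N = 100000
for k in range(N):
    pipe.lpush("tasks", str((1, k)))
    if (k + 1) % 1000 == 0:  # flush a batch of 1000 commands
        pipe.execute()
pipe.execute()  # flush the remainder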
Celery RabbitMQ
Task:
# celery_task.py
from celery import Celery

app = Celery('tasks', broker='amqp://guest@broker//', backend="rpc://")
app.conf.task_reject_on_worker_lost = True  # requeue the task if the worker dies
app.conf.task_acks_late = True              # acknowledge only after the task ran

@app.task
def add(x, y):
    return x + y
Master:
import celery_task as c
from time import time

# Submit N tasks asynchronously.
tasks_t = time()
tasks = []
N = 100000
for k in range(N):
    tasks.append(c.add.delay(2, k))

# Collect the results in submission order.
ready = []
collect_start = time()
for t in tasks:
    res = t.wait()
    print(res)
    ready.append(res)
done = time()
print("celery", collect_start - tasks_t, done - tasks_t)
Celery Redis
Task:
# celery_task.py, Redis variant
from celery import Celery

app = Celery('tasks', broker='redis://broker', backend="redis://gaba02")
app.conf.task_reject_on_worker_lost = True  # requeue the task if the worker dies
app.conf.task_acks_late = True              # acknowledge only after the task ran

@app.task
def add(x, y):
    return x + y
Master: Same as in Celery RabbitMQ