
Implementing a semaphore with Django and Postgres

Posted by gavin on Sept. 12, 2017, 11:12 a.m.

DEVELOPMENT

Recently I was working on a project that accesses a third-party web service. The service does not handle concurrency well and crashes if you send it more than a few requests at the same time. To prevent it from crashing, I needed to make sure we never send it too many requests at once. Fixing the service was not an option. Since it was called in a synchronous flow where the user expected a response right away, queuing the requests wasn't possible either. Rather than send the service too many requests, it was preferable to give the user an error immediately.

The data structure that ensures you never have more than a set number of processes accessing a resource in a concurrent environment is called a semaphore. A semaphore is initialized with a maximum number of accessors and then tracks how many processes are currently accessing the resource. When a process wants to access the resource, it tries to claim one of the slots in the semaphore, failing if there are already too many processes.
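Python's standard library has an in-process semaphore that shows the same fail-fast behavior. This minimal sketch uses `threading.BoundedSemaphore` with a non-blocking `acquire`; the names `service_slots`, `try_claim` and `release_slot` are illustrative, and the limit of 4 is arbitrary. It only coordinates threads inside one process, so it would not limit requests coming from several web server processes.

```python
import threading

# Hypothetical names; the limit of 4 is an arbitrary example value.
service_slots = threading.BoundedSemaphore(4)


def try_claim():
    """Claim a slot without waiting; return False if all slots are taken."""
    return service_slots.acquire(blocking=False)


def release_slot():
    """Give a claimed slot back (BoundedSemaphore raises ValueError on over-release)."""
    service_slots.release()
```

A fifth call to `try_claim()` while four slots are held returns `False` immediately instead of blocking, which is the "fail instead of wait" behavior described above.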

A semaphore requires shared mutable state that can be accessed in a concurrency-safe way. In a Django application, the database is the most convenient place to do this. We can use transactions, locks, and constraints to safely implement the invariants a semaphore needs.

We start with a model to track requests in progress:

from django.db import models


class ServiceRequest(models.Model):
    created_at = models.DateTimeField()
    completed_at = models.DateTimeField(null=True)
    timed_out_at = models.DateTimeField(null=True)
    slot = models.PositiveSmallIntegerField()

When we want to claim a semaphore slot, we insert a row where completed_at is null and slot is the lowest number not already used by any other in-flight row (one where both completed_at and timed_out_at are null). If there are no unused slot numbers less than the number of requests the semaphore allows, we raise an exception. When the request returns, we set completed_at to the current time. We need timed_out_at to handle failure: if a process crashes, we need a way to release its slot after a timeout.
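Before looking at the SQL, the rules in the paragraph above can be modeled in plain Python to check the slot arithmetic. This is a single-process sketch with hypothetical names (`Row`, `claim`, `release`, `timeout`) that mirror the post: a list stands in for the table and the clock is passed in explicitly. It provides none of the concurrency safety that the database lock and unique index give the real implementation.

```python
import datetime
from dataclasses import dataclass
from typing import List, Optional

NUM_SLOTS = 4
TIMEOUT = datetime.timedelta(seconds=90)


class NoSlotsAvailableException(Exception):
    pass


@dataclass
class Row:
    slot: int
    created_at: datetime.datetime
    completed_at: Optional[datetime.datetime] = None
    timed_out_at: Optional[datetime.datetime] = None

    @property
    def in_flight(self) -> bool:
        # Only rows that are neither completed nor timed out hold a slot.
        return self.completed_at is None and self.timed_out_at is None


def timeout(rows: List[Row], now: datetime.datetime) -> None:
    # Assume the holder crashed if a slot has been held longer than TIMEOUT.
    for row in rows:
        if row.in_flight and row.created_at < now - TIMEOUT:
            row.timed_out_at = now


def claim(rows: List[Row], now: datetime.datetime) -> Row:
    timeout(rows, now)
    used = {row.slot for row in rows if row.in_flight}
    # Take the lowest slot number that no in-flight row is using.
    for slot in range(NUM_SLOTS):
        if slot not in used:
            row = Row(slot=slot, created_at=now)
            rows.append(row)
            return row
    raise NoSlotsAvailableException


def release(row: Row, now: datetime.datetime) -> None:
    row.completed_at = now
```

Passing the clock in makes the timeout rule easy to check: a row created at time t counts as abandoned once the clock passes t + TIMEOUT, and its slot number becomes available again.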

The code to do all that looks like this:

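import datetime

from django.db import connection, models, transaction


class NoSlotsAvailableException(Exception):
    pass


# The post doesn't show TransactionNow; this definition is an assumption.
# On Postgres, CURRENT_TIMESTAMP is the start time of the current transaction
# (the same value as now()), while Django's built-in Now() compiles to
# STATEMENT_TIMESTAMP() there.
class TransactionNow(models.Func):
    template = 'CURRENT_TIMESTAMP'
    output_field = models.DateTimeField()


class ServiceRequest(models.Model):
    created_at = models.DateTimeField()
    completed_at = models.DateTimeField(null=True)
    timed_out_at = models.DateTimeField(null=True)
    slot = models.PositiveSmallIntegerField()

    NUM_SLOTS = 4
    TIMEOUT = datetime.timedelta(seconds=90)

    @classmethod
    def timeout(cls):
        """
        Cleans out slots that have been claimed for too long, assuming that
        their holder has crashed.
        """
        cls.objects.filter(
            created_at__lt=TransactionNow() - cls.TIMEOUT,
            completed_at=None,
            timed_out_at=None,
        ).update(
            timed_out_at=TransactionNow()
        )

    @classmethod
    def claim(cls):
        # Inside an outer transaction, the table lock below would be held
        # until that outer transaction ends.
        assert not connection.in_atomic_block, "If this is in a transaction, it will block too much"
        cls.timeout()
        table = cls._meta.db_table
        with transaction.atomic(), connection.cursor() as cursor:
            # We just need a lock level that conflicts with itself and the ROW
            # EXCLUSIVE lock taken by the insert.
            cursor.execute('LOCK TABLE "%s" IN SHARE ROW EXCLUSIVE MODE;' % (table,))
            # Insert the lowest slot number not held by an in-flight request.
            # If none is free, coalesce() falls back to slot 0, the partial
            # unique index rejects it, and ON CONFLICT DO NOTHING returns no row.
            objects = list(cls.objects.raw(
                """
                INSERT INTO "{table}" (created_at, slot)
                SELECT now(), coalesce(min(slot), 0)
                FROM generate_series(0, %s) slot
                WHERE NOT EXISTS (
                    SELECT *
                    FROM "{table}"
                    WHERE "{table}".slot = slot.slot
                      AND completed_at IS NULL
                      AND timed_out_at IS NULL
                )
                ON CONFLICT DO NOTHING
                RETURNING *;
                """.format(table=table),
                [cls.NUM_SLOTS - 1]
            ))
            if objects:
                return objects[0]
            else:
                raise NoSlotsAvailableException

    def release(self):
        self.completed_at = TransactionNow()
        self.save()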

To make this code safe, we need a unique constraint on slot. Since we only care about uniqueness among requests that are actually in flight, it has to be a partial unique index. An alternative would be to delete rows from the table instead of setting completed_at, but in my case it's useful to keep a log of the requests that have been made. Django did not support partial unique constraints when this was written (Django 2.2 later added UniqueConstraint with a condition argument), but we can create one with raw SQL in a migration:

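from django.db import migrations


class Migration(migrations.Migration):

    dependencies = [
        # Hypothetical: whichever migration creates the ServiceRequest table.
        ('someapp', '0001_initial'),
    ]

    operations = [
        migrations.RunSQL(
            """
            CREATE UNIQUE INDEX one_per_slot ON someapp_servicerequest (slot) WHERE (
                completed_at IS NULL AND timed_out_at IS NULL
            );
            """,
            """
            DROP INDEX one_per_slot;
            """
        )
    ]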

Now Postgres will enforce our constraint that slot numbers can't be reused by requests in flight.

The application code looks like this:

try:
    slot = ServiceRequest.claim()
except NoSlotsAvailableException:
    # Return an error to the user saying the system is currently overloaded.
    ...
else:
    try:
        # Do the request. Give it a timeout under 90 seconds (TIMEOUT) so the
        # slot isn't timed out and reused while this request is still running.
        ...
    finally:
        slot.release()
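One way to guarantee the slot is released even when the request raises is to wrap claim and release in a context manager. This is a minimal sketch, not part of the original code: `semaphore_slot` is a hypothetical helper, and it assumes `model` exposes a `claim()` classmethod returning an object with `release()`, as `ServiceRequest` does.

```python
import contextlib


@contextlib.contextmanager
def semaphore_slot(model):
    """Claim a slot from `model`, yield it, and always release it.

    `model` is assumed to behave like ServiceRequest: claim() returns an
    object with release(), or raises NoSlotsAvailableException, which
    propagates to the caller unchanged.
    """
    slot = model.claim()
    try:
        yield slot
    finally:
        # Runs on normal exit and when the body of the with block raises.
        slot.release()
```

Calling code would then catch `NoSlotsAvailableException` around `with semaphore_slot(ServiceRequest) as slot:` and make the request inside the `with` block.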

This solution is simple and doesn't require any other components, but it is probably not scalable if you are claiming and releasing the semaphore many times per second. It has worked well for me at about 3000 requests per hour, a bit under one per second.