
Sharding counters

Joe Gregorio
October 2008, updated October 2009

This is part four of a five-part series on effectively scaling your App Engine-based apps. For the other articles in the series, see Related links.

When developing an efficient application on Google App Engine, you need to pay attention to how often an entity is updated. While App Engine's datastore scales to support a huge number of entities, you can only expect to update any single entity or entity group about five times a second. That figure is an estimate; the actual update rate depends on several attributes of the entity, including how many properties it has, how large it is, and how many indexes need updating. Although a single entity or entity group can only be updated so quickly, App Engine excels at handling many parallel requests distributed across distinct entities, and we can take advantage of this by using sharding.

The question is, what if you had an entity that you wanted to update faster than five times a second? For example, you might count the number of votes in a poll, the number of comments, or even the number of visitors to your site. Take this simple example:

class Counter(db.Model):
  count = db.IntegerProperty()
      
@PersistenceCapable(identityType = IdentityType.APPLICATION)
public class Counter {
  @PrimaryKey
  @Persistent(valueStrategy = IdGeneratorStrategy.IDENTITY)
  private Long id;

  @Persistent
  private Integer count;

  public Long getId() {
    return id;
  }

  public Integer getCount() {
    return count;
  }

  // ...
}
      

If a single entity served as the counter and updates arrived too quickly, you would have contention: the serialized writes would stack up and start to time out. The way to solve this problem is a little counter-intuitive if you are coming from a relational database; the solution relies on the fact that reads from the App Engine datastore are extremely fast and cheap, since entities that have been recently read or updated are cached in memory. The way to reduce the contention is to build a sharded counter: break the counter up into N different counters. When you want to increment the counter, you pick one of the shards at random and increment it. When you want to know the total count, you read all of the counter shards and sum up their individual counts. The more shards you have, the higher the throughput you will have for increments on your counter. This technique works for a lot more than just counters. An important skill to learn is spotting the entities in your application that receive a lot of writes and then finding good ways to shard them.
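As a rough illustration of how shard count relates to write throughput, the sketch below turns the "about five updates per second per entity" estimate into a starting shard count. The function name `estimate_num_shards` and the `headroom` safety factor are assumptions made for this example, not part of App Engine. Because random shard selection spreads writes unevenly over short intervals, some headroom seems prudent, but the right value depends on your entities and should be measured.

```python
import math

# Assumption: roughly five writes per second per entity, the estimate quoted
# in this article. Real limits vary with entity size and index count.
WRITES_PER_SECOND_PER_ENTITY = 5.0


def estimate_num_shards(expected_writes_per_second, headroom=2.0):
    """Return a rough shard count for an expected peak write rate.

    headroom is a hypothetical safety factor that leaves room for the
    uneven load produced by picking shards at random.
    """
    if expected_writes_per_second <= 0:
        return 1
    needed = expected_writes_per_second * headroom / WRITES_PER_SECOND_PER_ENTITY
    return max(1, int(math.ceil(needed)))


# Example: a poll expected to peak at 50 votes per second.
suggested = estimate_num_shards(50)
```

Treat the result as a first guess only; since shards can be added later, it is reasonable to start low and increase the count if writes begin to time out.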

Here is a very simple implementation of a sharded counter:

from google.appengine.ext import db
import random

class SimpleCounterShard(db.Model):
    """Shards for the counter"""
    count = db.IntegerProperty(required=True, default=0)

NUM_SHARDS = 20

def get_count():
    """Retrieve the value for a given sharded counter."""
    total = 0
    for counter in SimpleCounterShard.all():
        total += counter.count
    return total

def increment():
    """Increment the value for a given sharded counter."""
    def txn():
        index = random.randint(0, NUM_SHARDS - 1)
        shard_name = "shard" + str(index)
        counter = SimpleCounterShard.get_by_key_name(shard_name)
        if counter is None:
            counter = SimpleCounterShard(key_name=shard_name)
        counter.count += 1
        counter.put()
    db.run_in_transaction(txn)
      
import PMF;

import java.util.List;
import java.util.Random;

import javax.jdo.PersistenceManager;
import javax.jdo.Query;

/**
 * A simple sharded counter. The total is the sum of the counts stored in all
 * SimpleCounterShard entities in the datastore. Incrementing picks one shard
 * at random and adds one to its count, creating the shard if it does not
 * exist yet.
 */
public class ShardedCounter {

  private static final int NUM_SHARDS = 20;

  /**
   * Retrieve the value of this sharded counter.
   *
   * @return Summed total of all shards' counts
   */
  public int getCount() {
    int sum = 0;
    PersistenceManager pm = PMF.get().getPersistenceManager();

    try {
      String query = "select from " + SimpleCounterShard.class.getName();
      List<SimpleCounterShard> shards =
          (List<SimpleCounterShard>) pm.newQuery(query).execute();
      if (shards != null && !shards.isEmpty()) {
        for (SimpleCounterShard shard : shards) {
          sum += shard.getCount();
        }
      }
    } finally {
      pm.close();
    }

    return sum;
  }

  /**
   * Increment the value of this sharded counter.
   */
  public void increment() {
    PersistenceManager pm = PMF.get().getPersistenceManager();

    Random generator = new Random();
    int shardNum = generator.nextInt(NUM_SHARDS);

    try {
      Query shardQuery = pm.newQuery(SimpleCounterShard.class);
      shardQuery.setFilter("shardNumber == numParam");
      shardQuery.declareParameters("int numParam");

      List<SimpleCounterShard> shards =
          (List<SimpleCounterShard>) shardQuery.execute(shardNum);
      SimpleCounterShard shard;

      // If the shard with the passed shard number exists, increment its count
      // by 1. Otherwise, create a new shard object, set its count to 1, and
      // persist it.
      if (shards != null && !shards.isEmpty()) {
        shard = shards.get(0);
        shard.setCount(shard.getCount() + 1);
      } else {
        shard = new SimpleCounterShard();
        shard.setShardNumber(shardNum);
        shard.setCount(1);
      }

      pm.makePersistent(shard);
    } finally {
      pm.close();
    }
  }
}
      

In get_count() (Python) and getCount() (Java), we simply loop over all the shards and add up the individual shard counts. In increment(), we choose a shard at random and then read, increment, and write it back to the datastore.

Note that we create the shards lazily, only creating them when they are first incremented. The lazy creation of the shards allows the number of shards to be increased (but never decreased) in the future if more are needed. The value of NUM_SHARDS could be doubled and the results from get_count() would not change since the query only selects the shards that have been added to the datastore, and increment() will lazily create shards that aren't there.
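The lazy-creation behavior described above can be checked without a datastore. The following is a minimal in-memory sketch: the class `InMemoryShardedCounter` is a hypothetical stand-in that keeps shards in a dictionary instead of `SimpleCounterShard` entities, so it shows only the counting logic, not transactions or contention.

```python
import random


class InMemoryShardedCounter(object):
    """Hypothetical stand-in for the datastore-backed counter.

    Shards live in a plain dict keyed by shard name and are created lazily.
    """

    def __init__(self, num_shards):
        self.num_shards = num_shards
        self.shards = {}  # shard name -> count

    def increment(self):
        # Pick a shard at random and create it on first use.
        index = random.randint(0, self.num_shards - 1)
        shard_name = "shard" + str(index)
        self.shards[shard_name] = self.shards.get(shard_name, 0) + 1

    def get_count(self):
        # Sum only the shards that actually exist.
        return sum(self.shards.values())


counter = InMemoryShardedCounter(num_shards=20)
for _ in range(1000):
    counter.increment()
total_before = counter.get_count()

# Doubling the shard count does not change the total, because the sum only
# covers shards that have already been created.
counter.num_shards = 40
total_after_resize = counter.get_count()

# Later increments may now land on the new, lazily created shards.
for _ in range(500):
    counter.increment()
```

After the resize the total is unchanged, and subsequent increments are simply spread over a wider range of shard names.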

That is useful as an example to learn from, but a more general-purpose counter would allow you to create named counters on the fly, increase the number of shards dynamically, and use memcache to speed up reads to shards. The example code that Brett Slatkin gave in his Google I/O talk does just that, and I've included that code here, along with a function to increase the number of shards for a particular counter:

from google.appengine.api import memcache
from google.appengine.ext import db
import random

class GeneralCounterShardConfig(db.Model):
    """Tracks the number of shards for each named counter."""
    name = db.StringProperty(required=True)
    num_shards = db.IntegerProperty(required=True, default=20)


class GeneralCounterShard(db.Model):
    """Shards for each named counter"""
    name = db.StringProperty(required=True)
    count = db.IntegerProperty(required=True, default=0)


def get_count(name):
    """Retrieve the value for a given sharded counter.

    Parameters:
      name - The name of the counter
    """
    total = memcache.get(name)
    if total is None:
        total = 0
        for counter in GeneralCounterShard.all().filter('name = ', name):
            total += counter.count
        memcache.add(name, str(total), 60)
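    # The cached total is stored as a string, so normalize the result to an
    # int whether it came from memcache or from summing the shards.
    return int(total)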


def increment(name):
    """Increment the value for a given sharded counter.

    Parameters:
      name - The name of the counter
    """
    config = GeneralCounterShardConfig.get_or_insert(name, name=name)
    def txn():
        index = random.randint(0, config.num_shards - 1)
        shard_name = name + str(index)
        counter = GeneralCounterShard.get_by_key_name(shard_name)
        if counter is None:
            counter = GeneralCounterShard(key_name=shard_name, name=name)
        counter.count += 1
        counter.put()
    db.run_in_transaction(txn)
    memcache.incr(name)


def increase_shards(name, num):
    """Increase the number of shards for a given sharded counter.
    Will never decrease the number of shards.

    Parameters:
      name - The name of the counter
      num - How many shards to use

    """
    config = GeneralCounterShardConfig.get_or_insert(name, name=name)
    def txn():
        if config.num_shards < num:
            config.num_shards = num
            config.put()
    db.run_in_transaction(txn)
      
import PMF;

import java.util.List;
import java.util.Random;
import javax.jdo.PersistenceManager;
import javax.jdo.Query;

/**
 * A counter which can be incremented rapidly.
 *
 * Capable of incrementing the counter and increasing the number of shards.
 * When incrementing, a random shard is selected to prevent a single shard
 * from being written to too frequently. If increments are being made too
 * quickly, increase the number of shards to divide the load. Performs
 * datastore operations using JDO.
 */
public class ShardedCounter {
  private String counterName;

  public ShardedCounter(String counterName) {
    this.counterName = counterName;
  }

  public String getCounterName() {
    return counterName;
  }

  /**
   * Retrieve the value of this sharded counter.
   *
   * @return Summed total of all shards' counts
   */
  public int getCount() {
    int sum = 0;
    PersistenceManager pm = PMF.get().getPersistenceManager();

    try {
      Query shardsQuery =
          pm.newQuery(GeneralCounterShard.class, "counterName == nameParam");
      shardsQuery.declareParameters("String nameParam");

      List<GeneralCounterShard> shards =
          (List<GeneralCounterShard>) shardsQuery.execute(counterName);
      if (shards != null && !shards.isEmpty()) {
        for (GeneralCounterShard current : shards) {
          sum += current.getCount();
        }
      }
    } finally {
      pm.close();
    }

    return sum;
  }

  /**
   * Increment the value of this sharded counter.
   */
  public void increment() {
    PersistenceManager pm = PMF.get().getPersistenceManager();

    // Find how many shards are in this counter.
    int shardCount = 0;
    try {
      Counter current = getThisCounter(pm);
      shardCount = current.getShardCount();
    } finally {
      pm.close();
    }

    // Choose the shard randomly from the available shards.
    Random generator = new Random();
    int shardNum = generator.nextInt(shardCount);

    pm = PMF.get().getPersistenceManager();
    try {
      Query randomShardQuery = pm.newQuery(GeneralCounterShard.class);
      randomShardQuery.setFilter(
          "counterName == nameParam && shardNumber == numParam");
      randomShardQuery.declareParameters("String nameParam, int numParam");

      List<GeneralCounterShard> shards = (List<GeneralCounterShard>)
          randomShardQuery.execute(counterName, shardNum);
      if (shards != null && !shards.isEmpty()) {
        GeneralCounterShard shard = shards.get(0);
        shard.increment(1);
        pm.makePersistent(shard);
      }
    } finally {
      pm.close();
    }
  }

  /**
   * Increase the number of shards for a given sharded counter.
   * Will never decrease the number of shards.
   *
   * @param  count Number of new shards to build and store
   * @return Total number of shards
   */
  public int addShards(int count) {
    PersistenceManager pm = PMF.get().getPersistenceManager();

    // Find the initial shard count for this counter.
    int numShards = 0;
    try {
      Counter current = getThisCounter(pm);
      if (current != null) {
        numShards = current.getShardCount().intValue();
        current.setShardCount(numShards + count);
        // Save the increased shard count for this counter.
        pm.makePersistent(current);
      }
    } finally {
      pm.close();
    }

    // Create new shard objects for this counter.
    pm = PMF.get().getPersistenceManager();
    try {
      for (int i = 0; i < count; i++) {
        GeneralCounterShard newShard =
            new GeneralCounterShard(getCounterName(), numShards);
        pm.makePersistent(newShard);
        numShards++;
      }
    } finally {
      pm.close();
    }

    return numShards;
  }

  /**
   * @return Counter datastore object matching this object's counterName value
   */
  private Counter getThisCounter(PersistenceManager pm) {
    Counter current = null;

    Query thisCounterQuery =
        pm.newQuery(Counter.class, "counterName == nameParam");
    thisCounterQuery.declareParameters("String nameParam");

    List<Counter> counter =
        (List<Counter>) thisCounterQuery.execute(counterName);
    if (counter != null && !counter.isEmpty()) {
      current = counter.get(0);
    }

    return current;
  }
}
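
The Python `get_count()` and `increment()` above pair the shards with a cache: reads are served from the cache when possible, a miss triggers a sum over the shards, and each increment also bumps the cached total if one is present. The sketch below imitates that flow in plain Python so it can be run anywhere. `FakeCache` is a hypothetical in-process stand-in written for this example and is not the App Engine memcache API; the dictionary `shards` likewise stands in for the shard entities, and the shard index is passed in explicitly to keep the example deterministic.

```python
import time


class FakeCache(object):
    """Tiny in-process cache with add/get/incr and expiry (illustration only)."""

    def __init__(self, clock=time.time):
        self._clock = clock
        self._data = {}  # key -> (value, expires_at)

    def get(self, key):
        entry = self._data.get(key)
        if entry is None or entry[1] <= self._clock():
            self._data.pop(key, None)  # drop expired entries
            return None
        return entry[0]

    def add(self, key, value, ttl):
        # Store the value only if the key is absent or expired.
        if self.get(key) is None:
            self._data[key] = (value, self._clock() + ttl)
            return True
        return False

    def incr(self, key):
        # Increment only when the key is present; a miss is a no-op.
        value = self.get(key)
        if value is None:
            return None
        self._data[key] = (value + 1, self._data[key][1])
        return value + 1


shards = {}  # (counter name, shard index) -> count
cache = FakeCache()


def get_count(name):
    total = cache.get(name)
    if total is None:
        # Cache miss: sum the shards and remember the result for 60 seconds.
        total = sum(count for (n, _), count in shards.items() if n == name)
        cache.add(name, total, 60)
    return total


def increment(name, index):
    shards[(name, index)] = shards.get((name, index), 0) + 1
    cache.incr(name)  # keeps a cached total in step; no-op if nothing is cached
```

One consequence of this design is that the cached total can briefly disagree with the shards, for example if an increment lands between the shard sum and the cache write; the 60-second expiry bounds how long such a discrepancy can last.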
      

Source

The Python source for both counters described above is available in the Google App Engine Samples project as sharded-counter. The Java source is available in the demos directory of the Google App Engine SDK project as shardedcounter. While the web interface to the examples isn't much to look at, it's instructive to use the admin interface and inspect the data models after you have incremented both counters a few times.

Conclusion

Sharding is one of many important techniques in building a scalable application, and hopefully these examples will give you ideas of where you can apply the technique in your application. The code in these articles is available under the Apache 2 license, so feel free to start with it as you build your solutions.

More Info

Watch Brett Slatkin's Google I/O talk "Building Scalable Web Applications with Google AppEngine".