June 27, 2010

Ordered Java Multi-channel Asynchronous Throttler

Some time ago I wrote a post describing a Java Multi-channel Asynchronous Throttler I had written. At the time I stated that it would preserve the order of calls, but as Asa commented on that post, this was not always the case. Here is a new version that does preserve order and passes Asa’s test. It works by placing incoming tasks on an internal queue. As part of this work I also extracted common code into new classes and created a ChannelThrottler interface. All the code detailed in this post (and the other throttler post) is available here.

I’m defining a multi-channel throttler as a class that accepts Runnable tasks (each with an optional channel identifier) and throttles the rate at which they are executed. Each throttling rule allows X calls in a time period Y. Different channels may have different rates, and there is also an overall rate that applies to every task. For more information see the previous blog post. The interface for a generic channel throttler is:

//imports skipped for prettier code display
public interface ChannelThrottler {
	Future<?> submit(Runnable task);
	Future<?> submit(Object channelKey, Runnable task);
}
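
As a rough sketch of how a caller might use this interface: the example below submits one task without a channel and one with a channel key, then waits on the returned Futures. ImmediateThrottler and ThrottlerUsageSketch are hypothetical names made up for this illustration. The stand-in does no throttling at all and exists only so the example compiles on its own.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.atomic.AtomicInteger;

// Same shape as the interface above, repeated so the sketch compiles on its own.
interface ChannelThrottler {
	Future<?> submit(Runnable task);
	Future<?> submit(Object channelKey, Runnable task);
}

// Hypothetical stand-in: applies no throttling and runs each task on the calling thread.
final class ImmediateThrottler implements ChannelThrottler {
	@Override public Future<?> submit(Runnable task) {
		return submit(null, task);
	}
	@Override public Future<?> submit(Object channelKey, Runnable task) {
		FutureTask<Void> future = new FutureTask<Void>(task, null);
		future.run();
		return future;
	}
}

final class ThrottlerUsageSketch {

	// Submits one task without a channel and one on a named channel, then reports how many ran.
	static int submitTwo(ChannelThrottler throttler) {
		final AtomicInteger ran = new AtomicInteger(0);
		Runnable task = new Runnable() {
			@Override public void run() { ran.incrementAndGet(); }
		};
		// No channel key: limited by the overall rate only.
		Future<?> overallOnly = throttler.submit(task);
		// With a key: limited by that channel's rate (if one is registered) and the overall rate.
		Future<?> onChannel = throttler.submit("CHANNEL1", task);
		try {
			overallOnly.get(); // blocks until the task has run
			onChannel.get();
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
		} catch (ExecutionException e) {
			throw new IllegalStateException(e.getCause());
		}
		return ran.get();
	}

	public static void main(String[] args) {
		System.out.println(submitTwo(new ImmediateThrottler())); // prints 2
	}
}
```

With a real throttler the calls to get() would block until the throttle rates allow each task to run.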

From the previous work, I extracted Rate as an independent class (previously it was an inner class). It keeps track of the call history for an individual throttle rate. Using this history it can calculate the earliest time at which a Runnable task may run and still meet the throttling requirements.

//imports skipped for prettier code display
public final class Rate {

	private final int numberCalls;
	private final int timeLength;
	private final TimeUnit timeUnit;
	private final LinkedList<Long> callHistory = new LinkedList<Long>();
	
	public Rate(int numberCalls, int timeLength, TimeUnit timeUnit) {
		this.numberCalls = numberCalls;
		this.timeLength = timeLength;
		this.timeUnit = timeUnit;
	}
	
	private long timeInMillis() {
		return timeUnit.toMillis(timeLength);
	}

	
	/* package */ void addCall(long callTime) {
		callHistory.addLast(callTime);
	}
	
	private void cleanOld(long now) {
		ListIterator<Long> i = callHistory.listIterator();
		long threshold = now-timeInMillis();
		while (i.hasNext()) {
			if (i.next()<=threshold) {
				i.remove();
			} else {
				break;
			}
		}
	}
	
	/* package */ long callTime(long now) {
		cleanOld(now);
		if (callHistory.size()<numberCalls) {
			return now;
		}
		long lastStart = callHistory.getLast()-timeInMillis();
		long firstPeriodCall=lastStart, call;
		int count = 0;
		Iterator<Long> i = callHistory.descendingIterator();
		while (i.hasNext()) {
			call = i.next();
			if (call<lastStart) {
				break;
			} else {
				count++;
				firstPeriodCall = call;
			}
		}
		if (count<numberCalls) {
			return firstPeriodCall+1;
		} else {
			return firstPeriodCall+timeInMillis()+1;
		}
	}
}
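
To make the arithmetic concrete, the sketch below restates the callTime and addCall logic as a single standalone method and traces five tasks submitted at time 0 against a limit of 2 calls per 1000 ms (the same numbers as testTotalChannelWithDoubleDelay further down). RateTraceSketch, nextCallTime and trace are names invented for this illustration. It is a restatement for reading purposes, not a replacement for the Rate class.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.Iterator;

final class RateTraceSketch {

	// Standalone restatement of Rate.callTime followed by Rate.addCall.
	static long nextCallTime(Deque<Long> history, int numberCalls, long periodMillis, long now) {
		// Forget calls that are a full period or more in the past.
		while (!history.isEmpty() && history.peekFirst() <= now - periodMillis) {
			history.removeFirst();
		}
		long result;
		if (history.size() < numberCalls) {
			result = now; // spare capacity: run immediately
		} else {
			// Count the calls inside the period that ends at the latest scheduled call.
			long lastStart = history.peekLast() - periodMillis;
			long firstPeriodCall = lastStart;
			int count = 0;
			Iterator<Long> newestFirst = history.descendingIterator();
			while (newestFirst.hasNext()) {
				long call = newestFirst.next();
				if (call < lastStart) {
					break;
				}
				count++;
				firstPeriodCall = call;
			}
			// Room left in that period: one millisecond after its first call.
			// Period full: one millisecond after its first call is a full period old.
			result = count < numberCalls ? firstPeriodCall + 1 : firstPeriodCall + periodMillis + 1;
		}
		history.addLast(result);
		return result;
	}

	// Returns the scheduled times for a burst of submissions that all arrive at the same instant.
	static long[] trace(int numberCalls, long periodMillis, int submissions, long now) {
		Deque<Long> history = new ArrayDeque<Long>();
		long[] times = new long[submissions];
		for (int i = 0; i < submissions; i++) {
			times[i] = nextCallTime(history, numberCalls, periodMillis, now);
		}
		return times;
	}

	public static void main(String[] args) {
		// Five tasks submitted at time 0 against a limit of 2 calls per 1000 ms.
		System.out.println(Arrays.toString(trace(2, 1000, 5, 0))); // prints [0, 0, 1001, 1002, 2002]
	}
}
```

The first two tasks run immediately, the third waits until the first call is more than a second old, and so on.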

I also extracted some code common to both implementations (this one and the previous one) as an abstract super class.

//imports skipped for prettier code display
/* package */ abstract class AbstractChannelThrottler implements ChannelThrottler {

	protected final Rate totalRate;
	protected final TimeProvider timeProvider;
	protected final ScheduledExecutorService scheduler;
	protected final Map<Object, Rate> channels = new HashMap<Object, Rate>();
	
	protected AbstractChannelThrottler(Rate totalRate, ScheduledExecutorService scheduler, Map<Object, Rate> channels, TimeProvider timeProvider) {
		this.totalRate = totalRate;
		this.scheduler = scheduler;
		this.channels.putAll(channels);
		this.timeProvider = timeProvider;
	}
	
	protected synchronized long callTime(Rate channel) {
		long now = timeProvider.getCurrentTimeInMillis();
		long callTime = totalRate.callTime(now);
		if (channel!=null) {
			callTime = Math.max(callTime, channel.callTime(now));
			channel.addCall(callTime);
		}
		totalRate.addCall(callTime);
		return callTime;			
	}
	
	protected long getThrottleDelay(Object channelKey) {
		long delay = callTime(channels.get(channelKey))-timeProvider.getCurrentTimeInMillis();
		return delay<0?0:delay;
	}
}

Now for the order-preserving throttler itself. The previous throttler worked by calculating the delay a Runnable task needed to meet the various throttle rates and then scheduling the task itself to execute just after that delay. This meant that if multiple tasks were scheduled for the same time, the JVM could run them in any order it chose. To get around this, the tasks are now stored in a FIFO queue. The delay needed to meet the rate requirements is still calculated, but now it is used to schedule a call that takes the first task off the queue and executes it. Note that the queue contains FutureTask objects, so each input Runnable is wrapped in a FutureTask. This maintains the proper interface and allows the process calling the throttler to see the progress of the task (or cancel it).

public final class QueueChannelThrottler extends AbstractChannelThrottler {
	
	private final Runnable processQueueTask = new Runnable() {
		@Override public void run() {
			FutureTask<?> task = tasks.poll();
			if (task!=null && !task.isCancelled()) {
				task.run();
			}
		}		
	};
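	// Added to by the submitting thread and polled by the scheduler thread, so use a thread-safe queue.
	private final Queue<FutureTask<?>> tasks = new ConcurrentLinkedQueue<FutureTask<?>>();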

	public QueueChannelThrottler(Rate totalRate) {
		this(totalRate, Executors.newSingleThreadScheduledExecutor(), new HashMap<Object, Rate>(), TimeProvider.SYSTEM_PROVIDER);
	}
	
	public QueueChannelThrottler(Rate totalRate, Map<Object, Rate> channels) {
		this(totalRate, Executors.newSingleThreadScheduledExecutor(), channels, TimeProvider.SYSTEM_PROVIDER);
	}
	
	public QueueChannelThrottler(Rate totalRate, ScheduledExecutorService scheduler, Map<Object, Rate> channels, TimeProvider timeProvider) {
		super(totalRate, scheduler, channels, timeProvider);
	}
	
	@Override public Future<?> submit(Runnable task) {
		return submit(null, task);
	} 
	
	@Override public Future<?> submit(Object channelKey, Runnable task) {
		long throttledTime = channelKey==null?callTime(null):callTime(channels.get(channelKey));
		FutureTask<Void> runTask = new FutureTask<Void>(task, null);
		tasks.add(runTask);
		long now = timeProvider.getCurrentTimeInMillis();
		scheduler.schedule(processQueueTask, throttledTime<now?0:throttledTime-now, TimeUnit.MILLISECONDS);
		return runTask;
	} 

}
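
The core idea, that a scheduled call is not tied to a particular task but simply runs whichever task is oldest, can be seen in isolation in the sketch below. QueueOrderSketch and runWithDelays are hypothetical names, and the delays are arbitrary values chosen to make the point (in the throttler they come from callTime). Later submissions are deliberately given shorter delays, yet the tasks still run in submission order.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

final class QueueOrderSketch {

	// Queues one task per delay, schedules one "run the head of the queue" call for each,
	// and returns the ids of the tasks in the order they actually ran.
	static List<Integer> runWithDelays(long... delaysMillis) {
		final Queue<Runnable> tasks = new ConcurrentLinkedQueue<Runnable>();
		final List<Integer> ranOrder = new CopyOnWriteArrayList<Integer>();
		ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
		// The scheduled call is not bound to any one task: it runs whichever task is oldest.
		Runnable processQueue = new Runnable() {
			@Override public void run() {
				Runnable head = tasks.poll();
				if (head != null) {
					head.run();
				}
			}
		};
		for (int i = 0; i < delaysMillis.length; i++) {
			final int id = i;
			// Queue the task first, then schedule a slot for "some" task to run.
			tasks.add(new Runnable() {
				@Override public void run() { ranOrder.add(id); }
			});
			scheduler.schedule(processQueue, delaysMillis[i], TimeUnit.MILLISECONDS);
		}
		scheduler.shutdown(); // already-scheduled delayed calls still run by default
		try {
			scheduler.awaitTermination(5, TimeUnit.SECONDS);
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
		}
		return new ArrayList<Integer>(ranOrder);
	}

	public static void main(String[] args) {
		// Later submissions get shorter delays, yet the tasks still run in submission order.
		System.out.println(runWithDelays(30, 20, 10)); // prints [0, 1, 2]
	}
}
```

Had each task been scheduled directly with its own delay, the same delays would have produced the order 2, 1, 0.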

For reference, here are the tests used. They are very similar to the tests for the previous throttler, with the addition of Asa’s ordered calls test.

//imports skipped for prettier code display
public class QueueChannelThrottlerTest {

	private static final String CHANNEL1 = "CHANNEL1";
	private static final String CHANNEL2 = "CHANNEL2";
	private DeterministicScheduler scheduler;
	private AtomicLong currentTime = new AtomicLong(0);
	private QueueChannelThrottler throttler;
	private AtomicInteger count = new AtomicInteger(0);
	private Runnable countIncrementTask = new Runnable() {@Override public void run() {count.incrementAndGet();}};
	
	@SuppressWarnings("serial")
	@Before public void setupThrottler() {
		scheduler = new DeterministicScheduler();
		currentTime.set(0);	
		Map<Object, Rate> channels = new HashMap<Object, Rate>();
		channels.put(CHANNEL1, new Rate(3, 1, TimeUnit.SECONDS));
		channels.put(CHANNEL2, new Rate(1, 1, TimeUnit.SECONDS));
		throttler = new QueueChannelThrottler(new Rate(2, 1, TimeUnit.SECONDS), scheduler, channels, new TimeProvider() {
			@Override public long getCurrentTimeInMillis() {return currentTime.get();}		
		});
		count = new AtomicInteger(0);
	}	
	
	@Test public void testTotalChannelWithNoDelay() throws Exception {
		throttler.submit(countIncrementTask);
		throttler.submit(countIncrementTask);
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(2, count.get());
	}
	
	@Test public void testTotalChannelWithDelay() throws Exception {
		throttler.submit(countIncrementTask);
		throttler.submit(countIncrementTask);
		throttler.submit(countIncrementTask);
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(2, count.get());
		scheduler.tick(1000, TimeUnit.MILLISECONDS);
		assertEquals(3, count.get());
	}
	
	@Test public void testTotalChannelWithDoubleDelay() throws Exception {
		throttler.submit(countIncrementTask);
		throttler.submit(countIncrementTask);
		throttler.submit(countIncrementTask);
		throttler.submit(countIncrementTask);
		throttler.submit(countIncrementTask);
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(2, count.get());
		scheduler.tick(500, TimeUnit.MILLISECONDS);
		assertEquals(2, count.get());
		scheduler.tick(500, TimeUnit.MILLISECONDS);
		assertEquals(3, count.get());
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(4, count.get());
		scheduler.tick(1000, TimeUnit.MILLISECONDS);
		assertEquals(5, count.get());
	}
	
	@Test public void testTotalChannelWithShortestDelay() throws Exception {
		throttler.submit(countIncrementTask);
		currentTime = new AtomicLong(777);	
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		throttler.submit(countIncrementTask);
		throttler.submit(countIncrementTask);
		currentTime = new AtomicLong(877);
		scheduler.tick(100, TimeUnit.MILLISECONDS);
		throttler.submit(countIncrementTask);
		
		assertEquals(2, count.get());
		scheduler.tick(124, TimeUnit.MILLISECONDS);
		assertEquals(3, count.get());
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		assertEquals(4, count.get());
	}
	
	@Test public void testChannel() throws Exception {
		throttler.submit(CHANNEL2, countIncrementTask);
		currentTime = new AtomicLong(777);	
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL2, countIncrementTask);
		currentTime = new AtomicLong(877);
		scheduler.tick(100, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL2, countIncrementTask);
		
		assertEquals(1, count.get());
		scheduler.tick(124, TimeUnit.MILLISECONDS);
		assertEquals(2, count.get());
		scheduler.tick(1000, TimeUnit.MILLISECONDS);
		assertEquals(2, count.get());
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(3, count.get());
	}
	
	@Test public void testChannelAndTotal() throws Exception {
		throttler.submit(CHANNEL1, countIncrementTask);
		currentTime = new AtomicLong(777);	
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL1, countIncrementTask);
		throttler.submit(CHANNEL1, countIncrementTask);
		currentTime = new AtomicLong(877);
		scheduler.tick(100, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL1, countIncrementTask);
		
		assertEquals(2, count.get());
		scheduler.tick(124, TimeUnit.MILLISECONDS);
		assertEquals(3, count.get());
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		assertEquals(4, count.get());
	}
	
	@Test public void testChannelAffectsTotal() throws Exception {
		throttler.submit(CHANNEL1, countIncrementTask);
		currentTime = new AtomicLong(777);	
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL1, countIncrementTask);
		throttler.submit(countIncrementTask);
		currentTime = new AtomicLong(877);
		scheduler.tick(100, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL1, countIncrementTask);
		
		assertEquals(2, count.get());
		scheduler.tick(124, TimeUnit.MILLISECONDS);
		assertEquals(3, count.get());
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		assertEquals(4, count.get());
	}
	
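	// Order in which the OrderedTask instances actually ran.
	private final List<Integer> runOrder = new ArrayList<Integer>();

	private class OrderedTask implements Runnable {
		private final int order;
		public OrderedTask(int order) {this.order=order;}
		@Override public void run() {
			// Record instead of asserting here: the throttler wraps tasks in a FutureTask,
			// which captures anything thrown from run(), so a failure would never reach JUnit.
			count.incrementAndGet();
			runOrder.add(order);
		}
	}
	
	@Test public void testChannelCallsAreOrdered() throws Exception {
		throttler.submit(CHANNEL1, new OrderedTask(1));
		throttler.submit(CHANNEL2, new OrderedTask(2));
		throttler.submit(CHANNEL1, new OrderedTask(3));
		throttler.submit(CHANNEL2, new OrderedTask(4));
		throttler.submit(CHANNEL2, new OrderedTask(5));
		throttler.submit(CHANNEL1, new OrderedTask(6));
		throttler.submit(CHANNEL1, new OrderedTask(7));
		scheduler.tick(5000, TimeUnit.MILLISECONDS);
		assertEquals(7, count.get());
		assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6, 7), runOrder);
	}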
	
	@Test public void testMultiChannel() throws Exception {
		throttler.submit(CHANNEL1, countIncrementTask);
		currentTime = new AtomicLong(777);	
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL2, countIncrementTask);
		throttler.submit(CHANNEL1, countIncrementTask);
		currentTime = new AtomicLong(877);
		scheduler.tick(100, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL2, countIncrementTask);
		throttler.submit(CHANNEL2, countIncrementTask);
		throttler.submit(CHANNEL1, countIncrementTask);
		throttler.submit(CHANNEL1, countIncrementTask);
		
		assertEquals(2, count.get());
		scheduler.tick(123, TimeUnit.MILLISECONDS);
		assertEquals(2, count.get());
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(3, count.get());
		scheduler.tick(778, TimeUnit.MILLISECONDS);
		assertEquals(4, count.get());
		scheduler.tick(1001, TimeUnit.MILLISECONDS);
		assertEquals(6, count.get());
		scheduler.tick(999, TimeUnit.MILLISECONDS);
		assertEquals(6, count.get());
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(7, count.get());
	}
	
	@Test public void testMultiChannelWithTotal() throws Exception {
		throttler.submit(CHANNEL1, countIncrementTask);
		currentTime = new AtomicLong(777);	
		scheduler.tick(777, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL2, countIncrementTask);
		throttler.submit(countIncrementTask);
		currentTime = new AtomicLong(877);
		scheduler.tick(100, TimeUnit.MILLISECONDS);
		throttler.submit(CHANNEL2, countIncrementTask);
		throttler.submit(CHANNEL2, countIncrementTask);
		throttler.submit(countIncrementTask);
		throttler.submit(CHANNEL1, countIncrementTask);
		
		assertEquals(2, count.get());
		scheduler.tick(123, TimeUnit.MILLISECONDS);
		assertEquals(2, count.get());
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(3, count.get());
		scheduler.tick(778, TimeUnit.MILLISECONDS);
		assertEquals(4, count.get());
		scheduler.tick(1001, TimeUnit.MILLISECONDS);
		assertEquals(6, count.get());
		scheduler.tick(999, TimeUnit.MILLISECONDS);
		assertEquals(6, count.get());
		scheduler.tick(1, TimeUnit.MILLISECONDS);
		assertEquals(7, count.get());
	}
	
	@Test
	public void scheduledTasksMonotonicallyIncreasing(){
		int numCalls = 50;
		int totalCalls = 1000;
		int ratePeriod = 100;
		final CountDownLatch latch = new CountDownLatch(totalCalls);
		Rate rate = new Rate(numCalls, ratePeriod, TimeUnit.MILLISECONDS);
		QueueChannelThrottler throttler = new QueueChannelThrottler(rate);
		final ConcurrentLinkedQueue<Long> base = new ConcurrentLinkedQueue<Long>();
		
		for(int i = 0; i < totalCalls; i++) {
			throttler.submit(new Runnable() {
				@Override public void run() {
					base.add(System.currentTimeMillis());
					latch.countDown();
				}			
			});
		}
	
		// wait for the tasks to finish, before exiting
		try {
			latch.await((totalCalls/numCalls)*ratePeriod, TimeUnit.MILLISECONDS);
		} catch (InterruptedException e) {
			fail();
		}

		assertEquals(totalCalls, base.size());
		long last = 0;
		for (Long next: base) {
			assertTrue(next >= last);
			last = next;
		}
	}
	
	@Test
	public void scheduledTasksShouldRunInOrder(){
		int numCalls = 50;
		int totalCalls = 1000;
		int ratePeriod = 100;
		CountDownLatch latch = new CountDownLatch(totalCalls);
		Rate rate = new Rate(numCalls, ratePeriod, TimeUnit.MILLISECONDS);
		QueueChannelThrottler throttler = new QueueChannelThrottler(rate);
		ConcurrentLinkedQueue<Integer> base = new ConcurrentLinkedQueue<Integer>();
		ConcurrentLinkedQueue<Integer> toCompare = new ConcurrentLinkedQueue<Integer>();
		
		for(int i = 0; i < totalCalls; i++) {
			throttler.submit(new RunnableImpl(i, toCompare, latch));
			base.add(i);
		}
	
		// wait for the tasks to finish, before exiting
		try {
			latch.await((totalCalls/numCalls)*ratePeriod, TimeUnit.MILLISECONDS);
		} catch (InterruptedException e) {
			fail();
		}

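		assertEquals(base.size(), toCompare.size());
		// Compare element by element. Loop on isEmpty() rather than an index against size():
		// poll() shrinks the queue, so an index loop would stop after checking only half of it.
		while (!toCompare.isEmpty()) {
			assertEquals(base.poll(), toCompare.poll());
		}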
	}

	public class RunnableImpl implements Runnable {
		public final int id;
		private final ConcurrentLinkedQueue<Integer> collection;
		private final CountDownLatch latch;
		
		public RunnableImpl(int id, ConcurrentLinkedQueue<Integer> collection, CountDownLatch latch) {
			this.id = id;
			this.collection = collection;
			this.latch = latch;
		}
		
		public void run() {
			collection.add(id);
			latch.countDown();
		}
	}
}

Update 12 Feb 2012: I was recently asked to specify the license for the code on this page. Searching the web for a license that accurately captured my feelings, I found the WTFPL. The code on this page and at the download link (http://www.cordinc.com/projects/throttler.zip) is available under the WTFPL. Basically, you can do what you want with the code on this page: I place no restrictions on its use, but neither are there any guarantees. It solved the problem I had, but I can’t vouch for any other usage. Use at your own risk (but if you find a bug I’d be interested to hear about it).

