@Throttled: the CDI extension


In Java EE, throttling is often done with a stateless bean because stateless beans are pooled by design and the pool provides a contention point. IMO this is more a workaround than a solution to the throttling need, and a small CDI extension can be worth it.

The first drawback of such a solution is that it impacts bean management, which is not the goal: depending on the injections you have, you can get unexpected behavior and use too many resources. It also implies a particular bean lifecycle (in particular when an exception is thrown) which may not suit the application. Finally, it means you can't store any state in your bean and have to delegate it elsewhere if needed, because you don't have a single instance.

To align the code with the original goal, which is just to prevent too many concurrent accesses, we'll write a @Throttled extension.

Throttling: semaphore to the rescue

A naive but simple and working implementation is to use a semaphore. At a high level, a semaphore is a counter which can't go below 0: acquiring a permit decrements it (waiting if none is left) and releasing a permit increments it. This is exactly what we expect from a throttler.

The Semaphore API is also smart enough to support a timeout (i.e. try to get access for N units of time), which avoids blocking forever.
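
To make this concrete, here is a minimal sketch using only the JDK. The class and method names (SemaphoreThrottleSketch, call) are made up for illustration and are not part of the extension; the sketch only shows the acquire-with-timeout, run, release pattern:

```java
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

// Illustrative helper (hypothetical name): runs a task under a semaphore.
class SemaphoreThrottleSketch {
    private final Semaphore semaphore;

    SemaphoreThrottleSketch(final int permits) {
        this.semaphore = new Semaphore(permits);
    }

    <T> T call(final Supplier<T> task, final long timeout, final TimeUnit unit) {
        try {
            // wait at most `timeout` for one permit instead of blocking forever
            if (!semaphore.tryAcquire(timeout, unit)) {
                throw new IllegalStateException("No permit available within " + timeout + " " + unit);
            }
        } catch (final InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt flag set
            throw new IllegalStateException("Interrupted while waiting for a permit", e);
        }
        try {
            return task.get();
        } finally {
            semaphore.release(); // the permit is returned even if the task fails
        }
    }

    int availablePermits() {
        return semaphore.availablePermits();
    }
}
```

With a single permit, a second caller arriving while the first one is still running gives up after the timeout instead of piling up.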

API: @Throttled

We'll keep the API very simple and just allow configuring how long we are willing to wait for the permission to access the method (0, the default, means waiting without limit) and the number of permits an invocation takes. The latter allows adding a weight to accesses:

import javax.enterprise.util.Nonbinding;
import javax.interceptor.InterceptorBinding;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import java.util.concurrent.TimeUnit;

import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

@InterceptorBinding
@Retention(RUNTIME)
@Target({TYPE, METHOD})
public @interface Throttled {
    @Nonbinding
    long timeout() default 0L;

    @Nonbinding
    TimeUnit timeoutUnit() default TimeUnit.MILLISECONDS;

    @Nonbinding
    int weight() default 1;
}
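
To illustrate what the weight means in terms of permits, here is a small JDK-only sketch. WeightedAccessSketch and tryRun are hypothetical names used only for this example; it relies on the multi-permit variants of Semaphore (tryAcquire(int) and release(int)):

```java
import java.util.concurrent.Semaphore;

// Illustrative only: shows how a per-call weight maps onto semaphore permits.
class WeightedAccessSketch {
    static boolean tryRun(final Semaphore semaphore, final int weight, final Runnable task) {
        // non-blocking: fails right away if fewer than `weight` permits are free
        if (!semaphore.tryAcquire(weight)) {
            return false;
        }
        try {
            task.run();
            return true;
        } finally {
            semaphore.release(weight); // give back exactly what was taken
        }
    }
}
```

With 10 permits, a call of weight 8 leaves room for a call of weight 2 but not for a call of weight 5, so expensive methods can be made to count more than cheap ones against the same semaphore.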

This is the access API, but we also need to configure which semaphore we use and how it is configured. For that we add a companion annotation called @Throttling:

import javax.enterprise.util.Nonbinding;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;

import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

@Retention(RUNTIME)
@Target({TYPE, METHOD})
public @interface Throttling {
    @Nonbinding
    Class<? extends SemaphoreFactory> factory() default SemaphoreFactory.class;

    @Nonbinding
    boolean fair() default false;

    @Nonbinding
    int permits() default 1;

    @Nonbinding // if empty use class name
    String name() default "";
}

This one defines which semaphore we use (name), its configuration (fair and permits, check the Semaphore javadoc for more details) and a SemaphoreFactory. The factory is there if you don't want the default semaphore but a custom one. Wait, didn't we say Semaphore was perfect for a throttler? Yes, but it is not distributed, so this factory allows switching to a Hazelcast semaphore for instance, making the throttling work in a cluster as well in an (almost) transparent manner.

The last part of the API is the SemaphoreFactory. It is pretty trivial and just returns a Semaphore from the parameters of the @Throttling annotation:

import javax.enterprise.inject.spi.AnnotatedMethod;
import java.util.concurrent.Semaphore;

public interface SemaphoreFactory {
    Semaphore newSemaphore(AnnotatedMethod<?> method, String name, boolean fair, int permits);
}
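
Since the factory has to return a java.util.concurrent.Semaphore, plugging in a distributed implementation probably means adapting it through a subclass. Here is a rough sketch of that idea. RemotePermits is a hypothetical interface standing in for whatever cluster-wide client you use (it is not a Hazelcast API), and only the three methods our interceptor calls are overridden:

```java
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

// Hypothetical abstraction over a cluster-wide permit store; not a real library API.
interface RemotePermits {
    void acquire(int permits) throws InterruptedException;

    boolean tryAcquire(int permits, long timeout, TimeUnit unit) throws InterruptedException;

    void release(int permits);
}

// Adapts RemotePermits to the JDK Semaphore type a factory has to return.
class DelegatingSemaphore extends Semaphore {
    private final RemotePermits delegate;

    DelegatingSemaphore(final RemotePermits delegate) {
        super(0); // the local permits of the parent class are never used
        this.delegate = delegate;
    }

    @Override
    public void acquire(final int permits) throws InterruptedException {
        delegate.acquire(permits);
    }

    @Override
    public boolean tryAcquire(final int permits, final long timeout, final TimeUnit unit)
            throws InterruptedException {
        return delegate.tryAcquire(permits, timeout, unit);
    }

    @Override
    public void release(final int permits) {
        delegate.release(permits);
    }
}
```

A custom SemaphoreFactory would then return new DelegatingSemaphore(...) from newSemaphore. Any other Semaphore method still hits the unused local counter, so a complete adapter would override those as well.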

The Throttler implementation

You probably noticed the @InterceptorBinding on the @Throttled annotation, so you know we still miss the interceptor wrapping the method call in a semaphore access to complete our implementation.

As often with interceptors, we'll delegate most of the logic to an internal @ApplicationScoped class. This avoids recreating the semaphore, which would break our throttling if the throttled beans are not @ApplicationScoped (@RequestScoped for instance), and it avoids recomputing all the metadata for each invocation.

The implementation simply extracts the @Throttled and @Throttling annotations (the latter only if present) and calls the semaphore factory once per method. Then it precomputes the invocation wrapper, which acquires the number of permits defined by the @Throttled annotation, with or without the timeout, and then releases them.

Lost? Here is the code:

import javax.annotation.Priority;
import javax.enterprise.context.ApplicationScoped;
import javax.enterprise.inject.Typed;
import javax.enterprise.inject.spi.AnnotatedMethod;
import javax.enterprise.inject.spi.AnnotatedType;
import javax.enterprise.inject.spi.BeanManager;
import javax.inject.Inject;
import javax.interceptor.AroundInvoke;
import javax.interceptor.Interceptor;
import javax.interceptor.InvocationContext;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

import static java.util.Optional.ofNullable;

@Throttled
@Interceptor
@Priority(Interceptor.Priority.LIBRARY_BEFORE)
public class ThrottledInterceptor implements Serializable {
    @Inject
    private LocalCache metadata;

    @AroundInvoke
    public Object invoke(final InvocationContext ic) throws Exception {
        return metadata.getOrCreateInvocation(ic).invoke(ic);
    }

    private static Semaphore onInterruption(final InterruptedException e) {
        // restore the interrupt flag so callers up the stack can still see the interruption
        Thread.currentThread().interrupt();
        throw new IllegalStateException("acquire() interrupted", e);
    }

    @ApplicationScoped
    @Typed(LocalCache.class)
    static class LocalCache implements SemaphoreFactory {
        private final ConcurrentMap<String, Semaphore> semaphores = new ConcurrentHashMap<>();
        private final ConcurrentMap<Method, Invocation> providers = new ConcurrentHashMap<>();

        @Inject
        private BeanManager beanManager;

        Invocation getOrCreateInvocation(final InvocationContext ic) {
            return providers.computeIfAbsent(ic.getMethod(), method -> {
                final Class declaringClass = method.getDeclaringClass();
                final AnnotatedType<Object> annotatedType = beanManager.createAnnotatedType(declaringClass);
                final Optional<AnnotatedMethod<? super Object>> annotatedMethod = annotatedType.getMethods().stream()
                    .filter(am -> am.getJavaMember().equals(method))
                    .findFirst();

                final Throttled config = annotatedMethod
                    .map(am -> am.getAnnotation(Throttled.class))
                    .orElseGet(() -> annotatedType.getAnnotation(Throttled.class));
                final Optional<Throttling> sharedConfig =
                    ofNullable(annotatedMethod.map(am -> am.getAnnotation(Throttling.class))
                        .orElseGet(() -> annotatedType.getAnnotation(Throttling.class)));

                final SemaphoreFactory factory = sharedConfig.map(Throttling::factory).filter(f -> f != SemaphoreFactory.class)
                    .map(c -> SemaphoreFactory.class.cast(beanManager.getReference(beanManager.resolve(beanManager.getBeans(c)), SemaphoreFactory.class, beanManager.createCreationalContext(null))))
                    .orElse(this);

                final Semaphore semaphore = factory.newSemaphore(
                    annotatedMethod.orElseThrow(() -> new IllegalStateException("No annotated method for " + method)),
                    sharedConfig.map(Throttling::name).filter(n -> !n.isEmpty()).orElseGet(declaringClass::getName),
                    sharedConfig.map(Throttling::fair).orElse(false), sharedConfig.map(Throttling::permits).orElse(1));
                final long timeout = config.timeoutUnit().toMillis(config.timeout());
                final int weight = config.weight();
                return new Invocation(semaphore, weight, timeout);
            });
        }

        @Override
        public Semaphore newSemaphore(final AnnotatedMethod<?> method, final String name, final boolean fair, final int permits) {
            return semaphores.computeIfAbsent(name, key -> new Semaphore(permits, fair));
        }
    }

    private static final class Invocation {
        private final int weight;
        private final Semaphore semaphore;
        private final long timeout;

        private Invocation(final Semaphore semaphore, final int weight, final long timeout) {
            this.semaphore = semaphore;
            this.weight = weight;
            this.timeout = timeout;
        }

        Object invoke(final InvocationContext context) throws Exception {
            if (timeout > 0) {
                try {
                    if (!semaphore.tryAcquire(weight, timeout, TimeUnit.MILLISECONDS)) {
                        throw new IllegalStateException("Can't acquire " + weight + " permits for " + context.getMethod() + " in " + timeout + "ms");
                    }
                } catch (final InterruptedException e) {
                    return onInterruption(e);
                }
            } else {
                try {
                    semaphore.acquire(weight);
                } catch (final InterruptedException e) {
                    return onInterruption(e);
                }
            }
            try {
                return context.proceed();
            } finally {
                semaphore.release(weight);
            }
        }
    }
}

LocalCache is where we store the semaphores for the application and where the invocation wrapper is built. Invocation is the actual wrapping of the method invocation. The interceptor itself is then very simple: it uses the invocation computed for the method to wrap the method call with the access throttling.

Side note: using @Priority enables the interceptor without any beans.xml change.

Usage

Usage is now pretty natural:

@RequestScoped
@Path("demo")
@Throttling(name = "fast-webservice", permits = 50)
public static class Service2 {
    private final Collection<String> called = new ArrayList<>();

    @GET
    @Throttled(timeout = 5, timeoutUnit = TimeUnit.SECONDS)
    public String get() {
        return "...";
    }
}

Potential enhancement

As with most annotation-based solutions, you can still enhance this one with a few features:

  • support configuration of the values: @Throttling(permits = "${my.permits.in.system.properties:50}") for instance (permits would then have to be a String)
  • JAX-RS integration: our implementation throws an IllegalStateException if the timeout occurs, but in real life you would create a ThrottlingException and its companion ExceptionMapper to return a proper response to the client when it is used for a JAX-RS endpoint
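
For the first point, the placeholder resolution could look like the sketch below. It assumes the attribute becomes a String and that values come from system properties. PlaceholderSketch and its methods are hypothetical names, and the ${key:default} syntax is simply the one used in the bullet above:

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Illustrative resolver for values such as "${my.key:50}"; names are hypothetical.
class PlaceholderSketch {
    // group 1 = property key, group 2 = optional default after the first ':'
    private static final Pattern PLACEHOLDER = Pattern.compile("\\$\\{([^:}]+)(?::([^}]*))?\\}");

    static String resolve(final String raw) {
        final Matcher matcher = PLACEHOLDER.matcher(raw);
        if (!matcher.matches()) {
            return raw; // plain literal such as "50"
        }
        final String key = matcher.group(1);
        final String fallback = matcher.group(2); // null when no default is given
        final String value = System.getProperty(key, fallback);
        if (value == null) {
            throw new IllegalArgumentException("No value and no default for " + key);
        }
        return value;
    }

    static int resolveInt(final String raw) {
        return Integer.parseInt(resolve(raw).trim());
    }
}
```

LocalCache would then call something like resolveInt(throttling.permits()) before asking the factory for the semaphore.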

However, this implementation is nice because it is not a workaround for the throttling need, and it takes very few lines of code thanks to the Semaphore implementation provided by the JDK.
