Getting Stomp working with Spring was very easy. It just didn’t feel that easy at first, because the integration test I wrote for it seemed to fail at least as often as it succeeded. In this post I will take you through my journey towards a stable integration test for Spring WebSocket.

Initial test setup

My example test is quite simple:

Listing 1. The test method
@Test
public void stompTest() throws Exception {
    final String message = "myMessage";
    messagingTemplate.convertAndSend(SUBSCRIPTION_TOPIC, message);
    final String response = receivedMessages.poll(5, SECONDS);
    Assert.assertEquals(message, response);
}

I send a simple String to a predefined topic using a SimpMessagingTemplate and receive the message through a BlockingQueue (described below), after which I validate the received String.

Intercepting Messages

To catch the Stomp messages and add them to the BlockingQueue, I have extended StompSessionHandlerAdapter with my own handler:

Listing 2. A custom implementation of StompSessionHandlerAdapter
private class MySessionHandler extends StompSessionHandlerAdapter {

    @Override
    public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
        session.subscribe(SUBSCRIPTION_TOPIC, this);
    }

    @Override
    public void handleException(StompSession session, StompCommand command, StompHeaders headers, byte[] payload, Throwable exception) {
        LOGGER.warn("Stomp Error:", exception);
    }

    @Override
    public void handleTransportError(StompSession session, Throwable exception) {
        super.handleTransportError(session, exception);
        LOGGER.warn("Stomp Transport Error:", exception);
    }

    @Override
    public Type getPayloadType(StompHeaders headers) {
        return String.class;
    }

    @Override
    @SuppressWarnings("unchecked")
    public void handleFrame(StompHeaders stompHeaders, Object o) {
        LOGGER.info("Handle Frame with payload: {}", o);
        try {
            receivedMessages.offer((String) o, 500, MILLISECONDS);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
    }
}

As you can see, five methods are overridden:

  • afterConnected() makes sure the topic is subscribed to after the connection with the broker is established.

  • handleException() logs any exception that would otherwise be silently consumed.

  • handleTransportError() lets us know if anything goes wrong on the transport layer.

  • getPayloadType() returns the payload type which is used to pick the right message converter.

  • handleFrame() gets called every time a message is received on one of the subscribed topics. In this implementation I add the received message to the receivedMessages BlockingQueue.
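The BlockingQueue is what lets the test thread wait for a frame that arrives on a different thread. Below is a plain-Java sketch of that handoff, with no Spring involved and under the assumption that a separate thread plays the role of the STOMP client: the hypothetical `deliverAsync` stands in for the client thread calling handleFrame(), and `awaitMessage` blocks in poll() with a timeout, just as the test method does.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;

class QueueHandoffSketch {

    // Shared between the "client" thread (producer) and the test thread (consumer),
    // like receivedMessages in the test class.
    static final BlockingQueue<String> receivedMessages = new LinkedBlockingDeque<>();

    // Hypothetical stand-in for the STOMP client thread invoking handleFrame().
    static Thread deliverAsync(String payload, long delayMillis) {
        Thread client = new Thread(() -> {
            try {
                Thread.sleep(delayMillis);
                // Same call as in handleFrame(): give up after 500 ms if the queue is full.
                receivedMessages.offer(payload, 500, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        client.start();
        return client;
    }

    // The test thread waits up to the timeout; null means nothing arrived in time.
    static String awaitMessage(long timeoutSeconds) {
        try {
            return receivedMessages.poll(timeoutSeconds, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        }
    }

    public static void main(String[] args) {
        deliverAsync("myMessage", 100);
        System.out.println(awaitMessage(5)); // prints myMessage
    }
}
```

Because poll() returns null on timeout instead of blocking forever, a missing message shows up as a failed assertion rather than a hanging build.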

Setting up the connection

The Stomp client is created and connected in the @Before method of the test class:

Listing 3. The @Before method
@Before
public void setup() throws InterruptedException, ExecutionException, TimeoutException {
    final String URL = "ws://localhost:" + port + "/myendpoint";
    receivedMessages = new LinkedBlockingDeque<>();

    final List<Transport> transportList = Collections.singletonList(new WebSocketTransport(new StandardWebSocketClient()));
    final WebSocketStompClient stompClient = new WebSocketStompClient(new SockJsClient(transportList));
    stompClient.setMessageConverter(new StringMessageConverter());

    session = stompClient.connect(URL, new MySessionHandler()).get(5, SECONDS);
}

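Here you can see how the WebSocketStompClient is created and connected. Notice how a new instance of the custom MySessionHandler is passed to the connect method, which returns a future; calling get(5, SECONDS) on it blocks until the session is established or five seconds have passed. I also set an explicit StringMessageConverter. If you want to send anything other than Strings, you should set the corresponding MessageConverter (and return the matching type from getPayloadType()).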

I also have a simple @After method to disconnect:

Listing 4. The @After method
@After
public void reset() {
    session.disconnect();
}

Initial results

If you run the test with this setup, it sometimes succeeds but sometimes fails. We don’t want flaky tests! But why is this happening? That’s the same question I had. So I turned on DEBUG level logging for the org.springframework.messaging package. This is an excerpt from the test log:

Connection established in session id=4f400ea0-2adc-cfcb-5875-a24c62f2638e
Processing CONNECT session=7699e297703d4dbfb235966568d6c826
Processing MESSAGE destination=/topic/myTopic session=null payload=myMessage
Processing SUBSCRIBE /topic/myTopic id=0 session=7699e297703d4dbfb235966568d6c826
Connection closed in session id=4f400ea0-2adc-cfcb-5875-a24c62f2638e
Processing DISCONNECT session=7699e297703d4dbfb235966568d6c826
Processing DISCONNECT session=7699e297703d4dbfb235966568d6c826

As you can see, the MESSAGE is processed before the SUBSCRIBE. In other words, the message is sent before the subscription is handled.

WAIT WHAT?!?!

Didn’t I subscribe in the @Before method and send the message in the actual @Test?

Yes, I did, but session.subscribe() only sends the SUBSCRIBE frame and returns; it does not wait for the broker to register the subscription. So the @Before method finishes and lets the @Test do its thing while the subscription has not been handled yet.

Callbacks and CountDownLatches

We could fix this problem by simply adding a Thread.sleep(1000) at the end of our @Before method. But then every test takes at least a second. And what happens if our build pipeline has a bad day and needs 2 seconds to handle the subscription?

Let’s try to do this without the Thread.sleep().

This doesn’t have to be hard: I just add a CountDownLatch to my @Before method, await it at the end of that method, and count it down in some callback of the subscription.
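The latch pattern itself can be sketched in plain Java, without Spring. In this sketch `subscribeAsync` is a hypothetical asynchronous API that does its work on another thread and invokes a callback when finished; `setup` mirrors the @Before method. The sketch also checks the boolean returned by await(), which is false when the timeout expires, so a missing callback fails loudly instead of letting the test continue.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

class LatchSetupSketch {

    // Hypothetical async API standing in for the subscription: the work runs on another
    // thread and onSubscribed is invoked once it has finished.
    static void subscribeAsync(Runnable onSubscribed) {
        CompletableFuture.runAsync(() -> {
            // ... registration work would happen here ...
            onSubscribed.run();
        });
    }

    // Mirrors the @Before method: start the async work, then block until the callback
    // has counted the latch down or the timeout expires.
    static boolean setup(long timeoutSeconds) {
        CountDownLatch subscribed = new CountDownLatch(1);
        subscribeAsync(subscribed::countDown);
        try {
            // await() returns false if the timeout expires before the count reaches zero
            return subscribed.await(timeoutSeconds, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        if (!setup(10)) {
            throw new IllegalStateException("Subscription was not confirmed in time");
        }
        System.out.println("subscribed, safe to run the test");
    }
}
```

The whole approach stands or falls with the callback: the latch only proves that the callback ran, so the callback must fire after the work it signals is really done.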

Listing 5. The new @Before method
@Before
public void setup() throws InterruptedException, ExecutionException, TimeoutException {
    final String URL = "ws://localhost:" + port + "/myendpoint";
    receivedMessages = new LinkedBlockingDeque<>();

    final List<Transport> transportList = Collections.singletonList(new WebSocketTransport(new StandardWebSocketClient()));
    final WebSocketStompClient stompClient = new WebSocketStompClient(new SockJsClient(transportList));
    stompClient.setMessageConverter(new StringMessageConverter());

    subscriptionCountDownLatch = new CountDownLatch(1);
    session = stompClient.connect(URL, new MySessionHandler()).get(5, SECONDS);
    subscriptionCountDownLatch.await(10, SECONDS);
}

The CountDownLatch is in place and will only let the @Before method finish once it reaches zero or times out. Now we have to call the countDown() method in some callback of session.subscribe().

Hmm, there doesn’t seem to be a way to check if the subscription is handled…

Oh, wait a minute! There is a SessionSubscribeEvent to which we can listen with an @EventListener method. This listener should be defined inside a Spring bean, so I’ll make it a @Service:

Listing 6. A subscription listener service
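import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.messaging.SessionSubscribeEvent;

@Service
public class StompSubscriptionListener {
    private static final Logger LOGGER = LoggerFactory.getLogger(StompSubscriptionListener.class);

    // Callbacks are registered and removed on the test thread but invoked on the thread
    // that publishes the event, so the list must be safe for concurrent access.
    private final List<Consumer<SessionSubscribeEvent>> callbacks = new CopyOnWriteArrayList<>();

    public void registerCallback(final Consumer<SessionSubscribeEvent> callback) {
        callbacks.add(callback);
    }

    public void removeAllCallbacks() {
        callbacks.clear();
    }

    // Invoked by Spring whenever a client sends a SUBSCRIBE frame
    @EventListener
    public void handleSubscribeEvent(final SessionSubscribeEvent sessionSubscribeEvent) {
        callbacks.forEach(callback -> callback.accept(sessionSubscribeEvent));
    }
}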

This service allows my test to register a callback for SessionSubscribeEvent events, in which I can count down my latch. Also, don’t forget to remove the callbacks in the @After method:

Listing 7. The @Before method updated with a subscription callback
@Before
public void setup() throws InterruptedException, ExecutionException, TimeoutException {
    final String URL = "ws://localhost:" + port + "/myendpoint";
    receivedMessages = new LinkedBlockingDeque<>();

    final List<Transport> transportList = Collections.singletonList(new WebSocketTransport(new StandardWebSocketClient()));
    final WebSocketStompClient stompClient = new WebSocketStompClient(new SockJsClient(transportList));
    stompClient.setMessageConverter(new StringMessageConverter());

    subscriptionCountDownLatch = new CountDownLatch(1);
    subscriptionListener.registerCallback(e -> subscriptionCountDownLatch.countDown());
    session = stompClient.connect(URL, new MySessionHandler()).get(5, SECONDS);
    subscriptionCountDownLatch.await(10, SECONDS);
}

Listing 8. Do not forget to remove the callbacks
@After
public void reset() {
    subscriptionListener.removeAllCallbacks();
    session.disconnect();
}

Here I add the countdown as a callback. This should work!

…At least, that’s what I thought because my test was now succeeding many times in a row. Until it failed! Again! WHY?!?!

Back to sleep

Apparently the actual handling of the subscription is not tied to the SessionSubscribeEvent. The callback fires around the same time the subscription is picked up for handling, but the two have no guaranteed order. So I have not eliminated the race condition; I have only narrowed its time window considerably.
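Why a latch released by an event is not enough can be shown with a toy model in plain Java. This is a hypothetical illustration, not how Spring is implemented: the toy broker notifies its listener first and registers the subscription afterwards, which is one of the interleavings the observation above allows. A `registrationGate` latch holds the registration back so the gap is reproducible.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

class EventRaceSketch {

    // Topics the toy broker has actually registered.
    final Set<String> registered = ConcurrentHashMap.newKeySet();
    // Holds registration back so the event-before-registration interleaving is reproducible.
    final CountDownLatch registrationGate = new CountDownLatch(1);

    // Toy broker: notify the listener that a SUBSCRIBE arrived, then register it later.
    Thread handleSubscribe(String topic, Consumer<String> subscribeEventListener) {
        Thread broker = new Thread(() -> {
            subscribeEventListener.accept(topic); // the event fires first...
            try {
                registrationGate.await();          // ...the registration happens afterwards
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            registered.add(topic);
        });
        broker.start();
        return broker;
    }

    public static void main(String[] args) throws InterruptedException {
        EventRaceSketch broker = new EventRaceSketch();
        CountDownLatch subscribed = new CountDownLatch(1);

        broker.handleSubscribe("/topic/myTopic", topic -> subscribed.countDown());
        subscribed.await(10, TimeUnit.SECONDS);

        // The latch is open, yet a message sent now would find no subscriber.
        System.out.println(broker.registered.contains("/topic/myTopic")); // prints false

        broker.registrationGate.countDown();
    }
}
```

In a real run nothing holds the registration back, so it usually completes a moment later; that short gap is exactly the window the failing runs fell into.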

I have searched for any other callbacks, events or even state I could use to make sure the test is in a correct working state before it leaves the @Before method, but I could not find anything.

So eventually I had to resort to the much threaded… I mean dreaded Thread.sleep(). That doesn’t make the subscribe event useless, though: because it shrinks the race window, a much shorter sleep is enough, and a shorter sleep means less overhead per test.

Listing 9. Unfortunate sleep in the @Before method
@Before
public void setup() throws InterruptedException, ExecutionException, TimeoutException {
    final String URL = "ws://localhost:" + port + "/myendpoint";
    receivedMessages = new LinkedBlockingDeque<>();

    final List<Transport> transportList = Collections.singletonList(new WebSocketTransport(new StandardWebSocketClient()));
    final WebSocketStompClient stompClient = new WebSocketStompClient(new SockJsClient(transportList));
    stompClient.setMessageConverter(new StringMessageConverter());

    subscriptionCountDownLatch = new CountDownLatch(1);
    subscriptionListener.registerCallback(e -> subscriptionCountDownLatch.countDown());
    session = stompClient.connect(URL, new MySessionHandler()).get(5, SECONDS);
    subscriptionCountDownLatch.await(10, SECONDS);

    Thread.sleep(20);
}

I ran the test many times with different sleep times and counted how often it failed:

  • 0 ms (No Thread.sleep()): the test failed about 12% of the time.

  • 1 ms: the test failed 8% of the time.

  • 2+ ms: the test did not fail anymore.

I have increased the sleep time to 20ms just to be sure.

The session.disconnect() call in the @After method is also asynchronous and should be handled the same way as session.subscribe(), for example by listening for the SessionDisconnectEvent.

And this is where I thought there was no better option. The evil Thread.sleep() had beaten me.

Or did it?

What about Awaitility?

With Awaitility it is possible to wait for a given condition to become true, but as I’ve mentioned, there is no state in the session or the client that can tell us whether the subscription is correctly registered.

Still, there is a way to check whether the client is correctly subscribed: simply check if it can receive messages. Let’s build this!

Listing 10. The conditional method to pass to Awaitility
private boolean isSubscribed() {
    final String message = UUID.randomUUID().toString();

    messagingTemplate.convertAndSend(SUBSCRIPTION_TOPIC, message);

    String response = null;
    try {
        response = receivedMessages.poll(20, MILLISECONDS);

        // drain the message queue before returning true
        while(response != null && !message.equals(response)) {
            LOGGER.debug("Draining message queue");
            response = receivedMessages.poll(20, MILLISECONDS);
        }

    } catch (InterruptedException e) {
        LOGGER.debug("Polling received messages interrupted", e);
    }

    return response != null;
}

This method sends a message and tries to receive it. Once a message is received, the subscription is working, so we return true. But before we return true and pass control to the test method, we must make sure there are no lingering messages in the queue, such as late arrivals from earlier attempts. We do this by draining the queue until we receive the message we just sent, which is why each attempt sends a random UUID.
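The drain step is easy to get subtly wrong, so here is the same loop from isSubscribed() pulled out into a plain-Java sketch that needs no broker. The method name `receivedMarker` is hypothetical; it polls with the same 20 ms timeout and discards everything that is not the marker it was given.

```java
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;

class DrainSketch {

    // Returns true only if the marker shows up; stale messages queued before it are discarded.
    static boolean receivedMarker(BlockingQueue<String> queue, String marker) {
        try {
            String response = queue.poll(20, TimeUnit.MILLISECONDS);
            while (response != null && !marker.equals(response)) {
                // The previous response was stale: fetch the next one.
                response = queue.poll(20, TimeUnit.MILLISECONDS);
            }
            return response != null;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        BlockingQueue<String> queue = new LinkedBlockingDeque<>();
        String marker = UUID.randomUUID().toString();

        // Late arrivals from earlier probes, followed by the probe we just sent.
        queue.add("stale-1");
        queue.add("stale-2");
        queue.add(marker);

        System.out.println(receivedMarker(queue, marker)); // prints true
        System.out.println(queue.isEmpty());               // prints true
    }
}
```

If only stale messages are present, the loop consumes them, then poll() times out and the method returns false, so the next attempt starts with an empty queue.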

Now we can use the conditional method with Awaitility instead of our CountDownLatch and Thread.sleep():

Listing 11. No more Thread.sleep or CountDownLatch
@Before
public void setup() throws InterruptedException, ExecutionException, TimeoutException {
    final String URL = "ws://localhost:" + port + "/myendpoint";

    final List<Transport> transportList = Collections.singletonList(new WebSocketTransport(new StandardWebSocketClient()));
    final WebSocketStompClient stompClient = new WebSocketStompClient(new SockJsClient(transportList));
    stompClient.setMessageConverter(new StringMessageConverter());

    receivedMessages = new LinkedBlockingDeque<>();
    session = stompClient.connect(URL, new MySessionHandler()).get(5, SECONDS);
    await().until(this::isSubscribed);
}

This approach may seem silly here, because our @Before method implicitly tests what we want to test in our @Test method, but hopefully your real tests do more than check whether a message can be sent and received.
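Conceptually, await().until(this::isSubscribed) keeps re-evaluating the condition until it returns true or a timeout expires. The plain-Java sketch below shows that idea with a hypothetical `awaitUntil` helper and arbitrary timeout and interval values; it is not Awaitility’s implementation, which is far more configurable.

```java
import java.time.Duration;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BooleanSupplier;

class PollUntilSketch {

    // Hypothetical helper: re-evaluate the condition until it is true or the timeout passes.
    static void awaitUntil(BooleanSupplier condition, Duration timeout, Duration pollInterval) {
        long deadline = System.nanoTime() + timeout.toNanos();
        while (!condition.getAsBoolean()) {
            if (System.nanoTime() - deadline >= 0) {
                throw new IllegalStateException("Condition not met within " + timeout);
            }
            try {
                // Only a polling interval: the wait ends as soon as the condition holds.
                Thread.sleep(pollInterval.toMillis());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting", e);
            }
        }
    }

    public static void main(String[] args) {
        AtomicInteger probes = new AtomicInteger();
        // Stand-in for isSubscribed(): pretend only the third probe gets its message back.
        awaitUntil(() -> probes.incrementAndGet() >= 3, Duration.ofSeconds(10), Duration.ofMillis(100));
        System.out.println("subscribed after " + probes.get() + " probes"); // prints "subscribed after 3 probes"
    }
}
```

The difference from the fixed Thread.sleep(20) is that the sleep here is not a guess at how long the subscription takes: the wait ends as soon as the condition holds, and the timeout turns a broken setup into a clear failure instead of a hang.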

Conclusion

Spring WebSocket can be a bit flaky to test, but with a good setup the tests become stable. And with Awaitility we don’t need callbacks and countdowns, and most importantly, we don’t need to sleep!

Resulting test class

A full example project can be found here: https://github.com/jcoreNL/spring-stomp-integration-test

Listing 12. The complete test class
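import java.lang.reflect.Type;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeoutException;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.messaging.simp.stomp.StompSession;
import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.Transport;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
// Awaitility 3.x and later; older versions used the com.jayway.awaitility package
import static org.awaitility.Awaitility.await;

@RunWith(SpringRunner.class)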
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class BaseStompIntegrationTest {
    private static final String SUBSCRIPTION_TOPIC = "/topic/myTopic";
    private static final Logger LOGGER = LoggerFactory.getLogger(BaseStompIntegrationTest.class);

    private BlockingQueue<String> receivedMessages;
    private StompSession session;

    @Value("${local.server.port}")
    private int port;

    @Autowired
    private SimpMessagingTemplate messagingTemplate;

    @Before
    public void setup() throws InterruptedException, ExecutionException, TimeoutException {
        final String URL = "ws://localhost:" + port + "/myendpoint";

        final List<Transport> transportList = Collections.singletonList(new WebSocketTransport(new StandardWebSocketClient()));
        final WebSocketStompClient stompClient = new WebSocketStompClient(new SockJsClient(transportList));
        stompClient.setMessageConverter(new StringMessageConverter());

        receivedMessages = new LinkedBlockingDeque<>();
        session = stompClient.connect(URL, new MySessionHandler()).get(5, SECONDS);
        await().until(this::isSubscribed);
    }

    @Test
    public void stompTest() throws Exception {
        final String message = "myMessage";
        messagingTemplate.convertAndSend(SUBSCRIPTION_TOPIC, message);
        final String response = receivedMessages.poll(5, SECONDS);
        Assert.assertEquals(message, response);
    }

    @After
    public void reset() throws InterruptedException {
        session.disconnect();
        await().until(() -> !session.isConnected());
    }

    private boolean isSubscribed() {
        final String message = UUID.randomUUID().toString();

        messagingTemplate.convertAndSend(SUBSCRIPTION_TOPIC, message);

        String response = null;
        try {
            response = receivedMessages.poll(20, MILLISECONDS);

            // drain the message queue before returning true
            while(response != null && !message.equals(response)) {
                LOGGER.debug("Draining message queue");
                response = receivedMessages.poll(20, MILLISECONDS);
            }

        } catch (InterruptedException e) {
            LOGGER.debug("Polling received messages interrupted", e);
        }

        return response != null;
    }

    private class MySessionHandler extends StompSessionHandlerAdapter {
        @Override
        public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
            session.subscribe(SUBSCRIPTION_TOPIC, this);
        }

        @Override
        public void handleException(StompSession session, StompCommand command, StompHeaders headers, byte[] payload, Throwable exception) {
            LOGGER.warn("Stomp Error:", exception);
        }

        @Override
        public void handleTransportError(StompSession session, Throwable exception) {
            super.handleTransportError(session, exception);
            LOGGER.warn("Stomp Transport Error:", exception);
        }

        @Override
        public Type getPayloadType(StompHeaders headers) {
            return String.class;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void handleFrame(StompHeaders stompHeaders, Object o) {
            LOGGER.info("Handle Frame with payload: {}", o);
            try {
                receivedMessages.offer((String) o, 500, MILLISECONDS);
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
    }
}