Fix data race in MessagePack encoder for concurrent server sends (#1486)

This commit is contained in:
Luke Daley
2026-04-04 14:26:16 -07:00
committed by GitHub
parent 58033598c7
commit 8e7eb2bd96
2 changed files with 111 additions and 6 deletions

View File

@@ -1,5 +1,5 @@
/* /*
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved. * Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -41,10 +41,17 @@ public abstract class AbstractMessagePackEncoder implements MessageEncoder {
@Override @Override
public final void encode(Message msg) throws IOException, ProtocolException { public final void encode(Message msg) throws IOException, ProtocolException {
packer.packArrayHeader(2); // Serialize access to the packer. In pkl server mode the main thread
packer.packInt(msg.type().getCode()); // (handling CreateEvaluatorRequest) and the executor thread (sending
encodeMessage(msg); // EvaluateResponse / ReadModuleRequest) call encode() concurrently.
packer.flush(); // Without this lock their writes interleave, corrupting the MessagePack
// stream. See JvmServerTest "concurrent encoding" for a regression test.
synchronized (packer) {
packer.packArrayHeader(2);
packer.packInt(msg.type().getCode());
encodeMessage(msg);
packer.flush();
}
} }
protected void packMapHeader(int size, @Nullable Object value1) throws IOException { protected void packMapHeader(int size, @Nullable Object value1) throws IOException {

View File

@@ -1,5 +1,5 @@
/* /*
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved. * Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -15,10 +15,14 @@
*/ */
package org.pkl.server package org.pkl.server
import java.io.PipedInputStream
import java.io.PipedOutputStream
import java.net.URI import java.net.URI
import java.nio.file.Path import java.nio.file.Path
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
import kotlin.io.path.createDirectories import kotlin.io.path.createDirectories
import kotlin.io.path.outputStream import kotlin.io.path.outputStream
import kotlin.io.path.writeText import kotlin.io.path.writeText
@@ -839,6 +843,100 @@ abstract class AbstractServerTest {
) )
} }
/**
 * Regression test for concurrent message encoding.
 *
 * Two writer threads push different message types ([CreateEvaluatorResponse] and
 * [EvaluateResponse]) through one shared [ServerMessagePackEncoder] into an in-memory pipe, while
 * a reader thread decodes every message from the other end of that pipe. If the encoder lets the
 * writes of two `encode()` calls interleave, the MessagePack framing is corrupted and the reader
 * either fails to decode or sees an unexpected message type.
 *
 * The test fails when any worker thread records an error, when the writers or the reader do not
 * finish within their timeouts, or when fewer than `iterations * 2` messages were decoded.
 *
 * Only meaningful with `USE_DIRECT_TRANSPORT = false` (the default); otherwise it returns
 * immediately.
 */
@Test
fun `concurrent encoding -- multiple evaluators with module reads`() {
  if (USE_DIRECT_TRANSPORT) return

  val pipeIn = PipedInputStream(1 shl 20) // 1 MB buffer
  val pipeOut = PipedOutputStream(pipeIn)
  val encoder = ServerMessagePackEncoder(pipeOut)
  val decoder = ServerMessagePackDecoder(pipeIn)

  val iterations = 2000
  val expectedMessages = iterations * 2
  val padding = ByteArray(8192) // large payload to widen the race window
  val errors = mutableListOf<Throwable>()
  val decoded = AtomicInteger(0)
  val done = CountDownLatch(2)

  // All worker threads report failures through this single, lock-protected entry point.
  fun recordError(error: Throwable) {
    synchronized(errors) { errors.add(error) }
  }

  // Builds a writer thread that calls [send] once per iteration and always counts down the latch.
  // Throwable (not just Exception) is caught so that an Error raised on a worker thread is
  // reported by the test instead of vanishing into the default uncaught-exception handler.
  // Daemon threads cannot keep the JVM alive if a worker gets stuck after the test has failed.
  fun writer(send: (Int) -> Unit): Thread =
    Thread {
        try {
          for (i in 0 until iterations) {
            send(i)
          }
        } catch (e: Throwable) {
          recordError(e)
        } finally {
          done.countDown()
        }
      }
      .apply { isDaemon = true }

  // Writer A: CreateEvaluatorResponse (small messages)
  val writerA = writer { i ->
    encoder.encode(CreateEvaluatorResponse(i.toLong(), i.toLong(), null))
  }

  // Writer B: EvaluateResponse (large messages with 8 KB payload)
  val writerB = writer { i ->
    encoder.encode(EvaluateResponse(i.toLong() + iterations, i.toLong(), padding, null))
  }

  // Reader: decode all messages, check each is one of the two expected types.
  val reader =
    Thread {
        try {
          while (decoded.get() < expectedMessages) {
            // A null result ends the loop early; the final count assertion then reports it.
            val msg = decoder.decode() ?: break
            decoded.incrementAndGet()
            when (msg) {
              is CreateEvaluatorResponse,
              is EvaluateResponse -> {}
              else ->
                recordError(AssertionError("Wrong message type: ${msg.javaClass.simpleName}"))
            }
          }
        } catch (e: Throwable) {
          recordError(e)
        }
      }
      .apply { isDaemon = true }

  reader.start()
  writerA.start()
  writerB.start()

  // await() returns false on timeout; keep the result so a writer hang is reported as such.
  val writersFinished = done.await(30, java.util.concurrent.TimeUnit.SECONDS)
  pipeOut.close()
  // join() with a timeout returns silently when the time elapses, so check liveness afterwards.
  reader.join(10_000)
  val readerFinished = !reader.isAlive

  synchronized(errors) {
    if (errors.isNotEmpty()) {
      throw AssertionError(
        "${errors.size} encoding errors (decoded ${decoded.get()}/$expectedMessages): " +
          errors.first().message,
        errors.first(),
      )
    }
  }
  if (!writersFinished) {
    throw AssertionError(
      "Writers did not finish within 30 seconds (decoded ${decoded.get()}/$expectedMessages)"
    )
  }
  if (!readerFinished) {
    throw AssertionError(
      "Reader did not finish within 10 seconds (decoded ${decoded.get()}/$expectedMessages)"
    )
  }
  assertThat(decoded.get()).isEqualTo(expectedMessages)
}
@Test @Test
fun `evaluate with project dependencies`(@TempDir tempDir: Path) { fun `evaluate with project dependencies`(@TempDir tempDir: Path) {
val cacheDir = tempDir.resolve("cache").createDirectories() val cacheDir = tempDir.resolve("cache").createDirectories()