Fix data race in MessagePack encoder for concurrent server sends (#1486)

This commit is contained in:
Luke Daley
2026-04-04 14:26:16 -07:00
committed by GitHub
parent 58033598c7
commit 8e7eb2bd96
2 changed files with 111 additions and 6 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,10 +15,14 @@
*/
package org.pkl.server
import java.io.PipedInputStream
import java.io.PipedOutputStream
import java.net.URI
import java.nio.file.Path
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
import kotlin.io.path.createDirectories
import kotlin.io.path.outputStream
import kotlin.io.path.writeText
@@ -839,6 +843,100 @@ abstract class AbstractServerTest {
)
}
/**
* Regression test for concurrent message encoding.
*
* The pkl server's main thread sends [CreateEvaluatorResponse] while the executor thread sends
* [EvaluateResponse]. Without synchronization on the encoder, these writes interleave on the
* output stream, corrupting the MessagePack framing.
*
* This test exercises the race directly: two threads write different message types through the
* same [ServerMessagePackEncoder] into a pipe, and a reader thread decodes every message. Any
* interleaved write produces a decode error.
*
* Only meaningful with `USE_DIRECT_TRANSPORT = false` (the default).
*/
@Test
fun `concurrent encoding -- multiple evaluators with module reads`() {
if (USE_DIRECT_TRANSPORT) return
val pipeIn = PipedInputStream(1 shl 20) // 1 MB buffer
val pipeOut = PipedOutputStream(pipeIn)
val encoder = ServerMessagePackEncoder(pipeOut)
val decoder = ServerMessagePackDecoder(pipeIn)
val iterations = 2000
val padding = ByteArray(8192) // large payload to widen the race window
val errors = mutableListOf<Throwable>()
val decoded = AtomicInteger(0)
val done = CountDownLatch(2)
// Writer A: CreateEvaluatorResponse (small messages)
val writerA = Thread {
try {
for (i in 0 until iterations) {
encoder.encode(CreateEvaluatorResponse(i.toLong(), i.toLong(), null))
}
} catch (e: Exception) {
synchronized(errors) { errors.add(e) }
} finally {
done.countDown()
}
}
// Writer B: EvaluateResponse (large messages with 8 KB payload)
val writerB = Thread {
try {
for (i in 0 until iterations) {
encoder.encode(EvaluateResponse(i.toLong() + iterations, i.toLong(), padding, null))
}
} catch (e: Exception) {
synchronized(errors) { errors.add(e) }
} finally {
done.countDown()
}
}
// Reader: decode all messages, check each is well-formed.
val reader = Thread {
try {
while (decoded.get() < iterations * 2) {
val msg = decoder.decode() ?: break
decoded.incrementAndGet()
when (msg) {
is CreateEvaluatorResponse -> {}
is EvaluateResponse -> {}
else ->
synchronized(errors) {
errors.add(AssertionError("Wrong message type: ${msg.javaClass.simpleName}"))
}
}
}
} catch (e: Exception) {
synchronized(errors) { errors.add(e) }
}
}
reader.start()
writerA.start()
writerB.start()
done.await(30, java.util.concurrent.TimeUnit.SECONDS)
pipeOut.close()
reader.join(10_000)
synchronized(errors) {
if (errors.isNotEmpty()) {
throw AssertionError(
"${errors.size} encoding errors (decoded ${decoded.get()}/${iterations * 2}): " +
errors.first().message,
errors.first(),
)
}
}
assertThat(decoded.get()).isEqualTo(iterations * 2)
}
@Test
fun `evaluate with project dependencies`(@TempDir tempDir: Path) {
val cacheDir = tempDir.resolve("cache").createDirectories()