mirror of
https://github.com/apple/pkl.git
synced 2026-04-10 10:53:40 +02:00
Fix data race in MessagePack encoder for concurrent server sends (#1486)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
|
||||
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -15,10 +15,14 @@
|
||||
*/
|
||||
package org.pkl.server
|
||||
|
||||
import java.io.PipedInputStream
|
||||
import java.io.PipedOutputStream
|
||||
import java.net.URI
|
||||
import java.nio.file.Path
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import kotlin.io.path.createDirectories
|
||||
import kotlin.io.path.outputStream
|
||||
import kotlin.io.path.writeText
|
||||
@@ -839,6 +843,100 @@ abstract class AbstractServerTest {
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Regression test for concurrent message encoding.
|
||||
*
|
||||
* The pkl server's main thread sends [CreateEvaluatorResponse] while the executor thread sends
|
||||
* [EvaluateResponse]. Without synchronization on the encoder, these writes interleave on the
|
||||
* output stream, corrupting the MessagePack framing.
|
||||
*
|
||||
* This test exercises the race directly: two threads write different message types through the
|
||||
* same [ServerMessagePackEncoder] into a pipe, and a reader thread decodes every message. Any
|
||||
* interleaved write produces a decode error.
|
||||
*
|
||||
* Only meaningful with `USE_DIRECT_TRANSPORT = false` (the default).
|
||||
*/
|
||||
@Test
|
||||
fun `concurrent encoding -- multiple evaluators with module reads`() {
|
||||
if (USE_DIRECT_TRANSPORT) return
|
||||
|
||||
val pipeIn = PipedInputStream(1 shl 20) // 1 MB buffer
|
||||
val pipeOut = PipedOutputStream(pipeIn)
|
||||
val encoder = ServerMessagePackEncoder(pipeOut)
|
||||
val decoder = ServerMessagePackDecoder(pipeIn)
|
||||
|
||||
val iterations = 2000
|
||||
val padding = ByteArray(8192) // large payload to widen the race window
|
||||
val errors = mutableListOf<Throwable>()
|
||||
val decoded = AtomicInteger(0)
|
||||
val done = CountDownLatch(2)
|
||||
|
||||
// Writer A: CreateEvaluatorResponse (small messages)
|
||||
val writerA = Thread {
|
||||
try {
|
||||
for (i in 0 until iterations) {
|
||||
encoder.encode(CreateEvaluatorResponse(i.toLong(), i.toLong(), null))
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
synchronized(errors) { errors.add(e) }
|
||||
} finally {
|
||||
done.countDown()
|
||||
}
|
||||
}
|
||||
|
||||
// Writer B: EvaluateResponse (large messages with 8 KB payload)
|
||||
val writerB = Thread {
|
||||
try {
|
||||
for (i in 0 until iterations) {
|
||||
encoder.encode(EvaluateResponse(i.toLong() + iterations, i.toLong(), padding, null))
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
synchronized(errors) { errors.add(e) }
|
||||
} finally {
|
||||
done.countDown()
|
||||
}
|
||||
}
|
||||
|
||||
// Reader: decode all messages, check each is well-formed.
|
||||
val reader = Thread {
|
||||
try {
|
||||
while (decoded.get() < iterations * 2) {
|
||||
val msg = decoder.decode() ?: break
|
||||
decoded.incrementAndGet()
|
||||
when (msg) {
|
||||
is CreateEvaluatorResponse -> {}
|
||||
is EvaluateResponse -> {}
|
||||
else ->
|
||||
synchronized(errors) {
|
||||
errors.add(AssertionError("Wrong message type: ${msg.javaClass.simpleName}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
synchronized(errors) { errors.add(e) }
|
||||
}
|
||||
}
|
||||
|
||||
reader.start()
|
||||
writerA.start()
|
||||
writerB.start()
|
||||
|
||||
done.await(30, java.util.concurrent.TimeUnit.SECONDS)
|
||||
pipeOut.close()
|
||||
reader.join(10_000)
|
||||
|
||||
synchronized(errors) {
|
||||
if (errors.isNotEmpty()) {
|
||||
throw AssertionError(
|
||||
"${errors.size} encoding errors (decoded ${decoded.get()}/${iterations * 2}): " +
|
||||
errors.first().message,
|
||||
errors.first(),
|
||||
)
|
||||
}
|
||||
}
|
||||
assertThat(decoded.get()).isEqualTo(iterations * 2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `evaluate with project dependencies`(@TempDir tempDir: Path) {
|
||||
val cacheDir = tempDir.resolve("cache").createDirectories()
|
||||
|
||||
Reference in New Issue
Block a user