mirror of
https://github.com/apple/pkl.git
synced 2026-04-21 16:01:31 +02:00
Fix data race in MessagePack encoder for concurrent server sends (#1486)
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
|
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@@ -41,10 +41,17 @@ public abstract class AbstractMessagePackEncoder implements MessageEncoder {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final void encode(Message msg) throws IOException, ProtocolException {
|
public final void encode(Message msg) throws IOException, ProtocolException {
|
||||||
packer.packArrayHeader(2);
|
// Serialize access to the packer. In pkl server mode the main thread
|
||||||
packer.packInt(msg.type().getCode());
|
// (handling CreateEvaluatorRequest) and the executor thread (sending
|
||||||
encodeMessage(msg);
|
// EvaluateResponse / ReadModuleRequest) call encode() concurrently.
|
||||||
packer.flush();
|
// Without this lock their writes interleave, corrupting the MessagePack
|
||||||
|
// stream. See JvmServerTest "concurrent encoding" for a regression test.
|
||||||
|
synchronized (packer) {
|
||||||
|
packer.packArrayHeader(2);
|
||||||
|
packer.packInt(msg.type().getCode());
|
||||||
|
encodeMessage(msg);
|
||||||
|
packer.flush();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void packMapHeader(int size, @Nullable Object value1) throws IOException {
|
protected void packMapHeader(int size, @Nullable Object value1) throws IOException {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
|
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@@ -15,10 +15,14 @@
|
|||||||
*/
|
*/
|
||||||
package org.pkl.server
|
package org.pkl.server
|
||||||
|
|
||||||
|
import java.io.PipedInputStream
|
||||||
|
import java.io.PipedOutputStream
|
||||||
import java.net.URI
|
import java.net.URI
|
||||||
import java.nio.file.Path
|
import java.nio.file.Path
|
||||||
|
import java.util.concurrent.CountDownLatch
|
||||||
import java.util.concurrent.ExecutorService
|
import java.util.concurrent.ExecutorService
|
||||||
import java.util.concurrent.Executors
|
import java.util.concurrent.Executors
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger
|
||||||
import kotlin.io.path.createDirectories
|
import kotlin.io.path.createDirectories
|
||||||
import kotlin.io.path.outputStream
|
import kotlin.io.path.outputStream
|
||||||
import kotlin.io.path.writeText
|
import kotlin.io.path.writeText
|
||||||
@@ -839,6 +843,100 @@ abstract class AbstractServerTest {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Regression test for concurrent message encoding.
|
||||||
|
*
|
||||||
|
* The pkl server's main thread sends [CreateEvaluatorResponse] while the executor thread sends
|
||||||
|
* [EvaluateResponse]. Without synchronization on the encoder, these writes interleave on the
|
||||||
|
* output stream, corrupting the MessagePack framing.
|
||||||
|
*
|
||||||
|
* This test exercises the race directly: two threads write different message types through the
|
||||||
|
* same [ServerMessagePackEncoder] into a pipe, and a reader thread decodes every message. Any
|
||||||
|
* interleaved write produces a decode error.
|
||||||
|
*
|
||||||
|
* Only meaningful with `USE_DIRECT_TRANSPORT = false` (the default).
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
fun `concurrent encoding -- multiple evaluators with module reads`() {
|
||||||
|
if (USE_DIRECT_TRANSPORT) return
|
||||||
|
|
||||||
|
val pipeIn = PipedInputStream(1 shl 20) // 1 MB buffer
|
||||||
|
val pipeOut = PipedOutputStream(pipeIn)
|
||||||
|
val encoder = ServerMessagePackEncoder(pipeOut)
|
||||||
|
val decoder = ServerMessagePackDecoder(pipeIn)
|
||||||
|
|
||||||
|
val iterations = 2000
|
||||||
|
val padding = ByteArray(8192) // large payload to widen the race window
|
||||||
|
val errors = mutableListOf<Throwable>()
|
||||||
|
val decoded = AtomicInteger(0)
|
||||||
|
val done = CountDownLatch(2)
|
||||||
|
|
||||||
|
// Writer A: CreateEvaluatorResponse (small messages)
|
||||||
|
val writerA = Thread {
|
||||||
|
try {
|
||||||
|
for (i in 0 until iterations) {
|
||||||
|
encoder.encode(CreateEvaluatorResponse(i.toLong(), i.toLong(), null))
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
synchronized(errors) { errors.add(e) }
|
||||||
|
} finally {
|
||||||
|
done.countDown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writer B: EvaluateResponse (large messages with 8 KB payload)
|
||||||
|
val writerB = Thread {
|
||||||
|
try {
|
||||||
|
for (i in 0 until iterations) {
|
||||||
|
encoder.encode(EvaluateResponse(i.toLong() + iterations, i.toLong(), padding, null))
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
synchronized(errors) { errors.add(e) }
|
||||||
|
} finally {
|
||||||
|
done.countDown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reader: decode all messages, check each is well-formed.
|
||||||
|
val reader = Thread {
|
||||||
|
try {
|
||||||
|
while (decoded.get() < iterations * 2) {
|
||||||
|
val msg = decoder.decode() ?: break
|
||||||
|
decoded.incrementAndGet()
|
||||||
|
when (msg) {
|
||||||
|
is CreateEvaluatorResponse -> {}
|
||||||
|
is EvaluateResponse -> {}
|
||||||
|
else ->
|
||||||
|
synchronized(errors) {
|
||||||
|
errors.add(AssertionError("Wrong message type: ${msg.javaClass.simpleName}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
synchronized(errors) { errors.add(e) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reader.start()
|
||||||
|
writerA.start()
|
||||||
|
writerB.start()
|
||||||
|
|
||||||
|
done.await(30, java.util.concurrent.TimeUnit.SECONDS)
|
||||||
|
pipeOut.close()
|
||||||
|
reader.join(10_000)
|
||||||
|
|
||||||
|
synchronized(errors) {
|
||||||
|
if (errors.isNotEmpty()) {
|
||||||
|
throw AssertionError(
|
||||||
|
"${errors.size} encoding errors (decoded ${decoded.get()}/${iterations * 2}): " +
|
||||||
|
errors.first().message,
|
||||||
|
errors.first(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertThat(decoded.get()).isEqualTo(iterations * 2)
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `evaluate with project dependencies`(@TempDir tempDir: Path) {
|
fun `evaluate with project dependencies`(@TempDir tempDir: Path) {
|
||||||
val cacheDir = tempDir.resolve("cache").createDirectories()
|
val cacheDir = tempDir.resolve("cache").createDirectories()
|
||||||
|
|||||||
Reference in New Issue
Block a user