diff --git a/pkl-server/src/main/kotlin/org/pkl/server/Server.kt b/pkl-server/src/main/kotlin/org/pkl/server/Server.kt index 7bc12f47..70dbb4d5 100644 --- a/pkl-server/src/main/kotlin/org/pkl/server/Server.kt +++ b/pkl-server/src/main/kotlin/org/pkl/server/Server.kt @@ -183,49 +183,54 @@ class Server(private val transport: MessageTransport) : AutoCloseable { } private fun createEvaluator(message: CreateEvaluatorRequest, evaluatorId: Long): BinaryEvaluator { - val modulePaths = message.modulePaths ?: emptyList() - val resolver = ModulePathResolver(modulePaths) - val allowedModules = message.allowedModules?.map { Pattern.compile(it) } ?: emptyList() - val allowedResources = message.allowedResources?.map { Pattern.compile(it) } ?: emptyList() - val rootDir = message.rootDir - val env = message.env ?: emptyMap() - val properties = message.properties ?: emptyMap() - val timeout = message.timeout - val cacheDir = message.cacheDir - val httpClient = - with(HttpClient.builder()) { - message.http?.proxy?.let { proxy -> - setProxy(proxy.address, proxy.noProxy ?: listOf()) - proxy.address?.let(IoUtils::setSystemProxy) - proxy.noProxy?.let { System.setProperty("http.nonProxyHosts", it.joinToString("|")) } + try { + val modulePaths = message.modulePaths ?: emptyList() + val resolver = ModulePathResolver(modulePaths) + val allowedModules = message.allowedModules?.map { Pattern.compile(it) } ?: emptyList() + val allowedResources = message.allowedResources?.map { Pattern.compile(it) } ?: emptyList() + val rootDir = message.rootDir + val env = message.env ?: emptyMap() + val properties = message.properties ?: emptyMap() + val timeout = message.timeout + val cacheDir = message.cacheDir + val httpClient = + with(HttpClient.builder()) { + message.http?.proxy?.let { proxy -> + setProxy(proxy.address, proxy.noProxy ?: listOf()) + proxy.address?.let(IoUtils::setSystemProxy) + proxy.noProxy?.let { System.setProperty("http.nonProxyHosts", it.joinToString("|")) } + } + message.http?.caCertificates?.let(::addCertificates) + message.http?.rewrites?.let(::setRewrites) + buildLazily() } - message.http?.caCertificates?.let(::addCertificates) - buildLazily() - } - val dependencies = - message.project?.let { proj -> - buildDeclaredDependencies(proj.projectFileUri, proj.dependencies, null) - } - log("Got dependencies: $dependencies") - return BinaryEvaluator( - StackFrameTransformers.defaultTransformer, - SecurityManagers.standard( - allowedModules, - allowedResources, - SecurityManagers.defaultTrustLevels, - rootDir, - ), - httpClient, - ClientLogger(evaluatorId, transport), - createModuleKeyFactories(message, evaluatorId, resolver), - createResourceReaders(message, evaluatorId, resolver), - env, - properties, - timeout, - cacheDir, - dependencies, - message.outputFormat, - ) + val dependencies = + message.project?.let { proj -> + buildDeclaredDependencies(proj.projectFileUri, proj.dependencies, null) + } + log("Got dependencies: $dependencies") + return BinaryEvaluator( + StackFrameTransformers.defaultTransformer, + SecurityManagers.standard( + allowedModules, + allowedResources, + SecurityManagers.defaultTrustLevels, + rootDir, + ), + httpClient, + ClientLogger(evaluatorId, transport), + createModuleKeyFactories(message, evaluatorId, resolver), + createResourceReaders(message, evaluatorId, resolver), + env, + properties, + timeout, + cacheDir, + dependencies, + message.outputFormat, + ) + } catch (e: IllegalArgumentException) { + throw ProtocolException(e.message ?: "Failed to create an evalutor. $e", e) + } } private fun createResourceReaders( diff --git a/pkl-server/src/main/kotlin/org/pkl/server/ServerMessagePackDecoder.kt b/pkl-server/src/main/kotlin/org/pkl/server/ServerMessagePackDecoder.kt index 48d1d7f7..51462d3d 100644 --- a/pkl-server/src/main/kotlin/org/pkl/server/ServerMessagePackDecoder.kt +++ b/pkl-server/src/main/kotlin/org/pkl/server/ServerMessagePackDecoder.kt @@ -99,8 +99,8 @@ class ServerMessagePackDecoder(unpacker: MessageUnpacker) : BaseMessagePackDecod getNullable(httpMap, "rewrites") ?.asMapValue() ?.map() - ?.mapKeys { it.key.asStringValue().asString() } - ?.mapValues { it.value.asStringValue().asString() } + ?.mapKeys { URI(it.key.asStringValue().asString()) } + ?.mapValues { URI(it.value.asStringValue().asString()) } return Http(caCertificates, proxy, rewrites) } diff --git a/pkl-server/src/main/kotlin/org/pkl/server/ServerMessagePackEncoder.kt b/pkl-server/src/main/kotlin/org/pkl/server/ServerMessagePackEncoder.kt index 25cc63b6..29658130 100644 --- a/pkl-server/src/main/kotlin/org/pkl/server/ServerMessagePackEncoder.kt +++ b/pkl-server/src/main/kotlin/org/pkl/server/ServerMessagePackEncoder.kt @@ -48,8 +48,8 @@ class ServerMessagePackEncoder(packer: MessagePacker) : BaseMessagePackEncoder(p packString("rewrites") packMapHeader(rewrites.size) for ((key, value) in rewrites) { - packString(key) - packString(value) + packString(key.toString()) + packString(value.toString()) } } } diff --git a/pkl-server/src/main/kotlin/org/pkl/server/ServerMessages.kt b/pkl-server/src/main/kotlin/org/pkl/server/ServerMessages.kt index 0da36bb0..139585c4 100644 --- a/pkl-server/src/main/kotlin/org/pkl/server/ServerMessages.kt +++ b/pkl-server/src/main/kotlin/org/pkl/server/ServerMessages.kt @@ -57,7 +57,7 @@ data class Http( /** Proxy settings */ val proxy: Proxy?, /** HTTP rewrites */ - val rewrites: Map?, + val rewrites: Map?, ) { override fun equals(other: Any?): Boolean { if (this === other) return true diff --git a/pkl-server/src/test/kotlin/org/pkl/server/AbstractServerTest.kt b/pkl-server/src/test/kotlin/org/pkl/server/AbstractServerTest.kt index 21a5c320..1a815ecf 100644 --- a/pkl-server/src/test/kotlin/org/pkl/server/AbstractServerTest.kt +++ b/pkl-server/src/test/kotlin/org/pkl/server/AbstractServerTest.kt @@ -60,6 +60,26 @@ abstract class AbstractServerTest { abstract val client: TestTransport + private val blankCreateEvaluatorRequest = + CreateEvaluatorRequest( + requestId = 1, + http = null, + allowedModules = null, + allowedResources = null, + clientModuleReaders = null, + clientResourceReaders = null, + modulePaths = null, + env = null, + properties = null, + timeout = null, + rootDir = null, + cacheDir = null, + outputFormat = null, + project = null, + externalModuleReaders = null, + externalResourceReaders = null, + ) + @Test fun `create and close evaluator`() { val evaluatorId = client.sendCreateEvaluatorRequest(123) @@ -931,6 +951,50 @@ abstract class AbstractServerTest { ) } + @Test + fun `http rewrites`() { + val evaluatorId = + client.sendCreateEvaluatorRequest( + http = + Http( + caCertificates = null, + proxy = null, + rewrites = mapOf(URI("https://example.com/") to URI("https://example.example/")), + ) + ) + client.send( + EvaluateRequest( + 1, + evaluatorId, + URI("repl:text"), + "res = import(\"https://example.com/foo.pkl\")", + "output.text", + ) + ) + val response = client.receive() + assertThat(response.error) + .contains( + "request was rewritten: https://example.com/foo.pkl -> https://example.example/foo.pkl" + ) + } + + @Test + fun `http rewrites -- invalid rule`() { + client.send( + blankCreateEvaluatorRequest.copy( + http = + Http( + caCertificates = null, + proxy = null, + rewrites = mapOf(URI("https://example.com") to URI("https://example.example/")), + ) + ) + ) + val response = client.receive() + assertThat(response.error) + .contains("Rewrite rule must end with '/', but was 'https://example.com'") + } + private val ByteArray.debugYaml get() = MessagePackDebugRenderer(this).output.trimIndent() diff --git a/pkl-server/src/test/kotlin/org/pkl/server/ServerMessagePackCodecTest.kt b/pkl-server/src/test/kotlin/org/pkl/server/ServerMessagePackCodecTest.kt index 20a4daf7..f0de03bd 100644 --- a/pkl-server/src/test/kotlin/org/pkl/server/ServerMessagePackCodecTest.kt +++ b/pkl-server/src/test/kotlin/org/pkl/server/ServerMessagePackCodecTest.kt @@ -96,7 +96,7 @@ class ServerMessagePackCodecTest { Http( proxy = Proxy(URI("http://foo.com:1234"), listOf("bar", "baz")), caCertificates = byteArrayOf(1, 2, 3, 4), - rewrites = mapOf("https://foo.com" to "https://bar.com"), + rewrites = mapOf(URI("https://foo.com/") to URI("https://bar.com/")), ), externalModuleReaders = mapOf("external" to externalReader, "external2" to externalReader), externalResourceReaders = mapOf("external" to externalReader),