Improve HTTP redirect following (#1637)

This implements HTTP redirect following ourselves.

The goal is:

1. All I/O is checked against `--allowed-resources` and
`--allowed-modules`, including HTTP redirects
2. HTTP rewrite rules can affect redirect following
3. HTTP headers can affect redirect following

---------

Co-authored-by: Islon Scherer <islonscherer@gmail.com>
This commit is contained in:
Daniel Chao
2026-06-08 11:13:48 -07:00
committed by GitHub
parent b993cc3bb1
commit d012285f7d
36 changed files with 465 additions and 129 deletions
@@ -175,7 +175,7 @@ class LanguageSnippetTestsEngine : AbstractLanguageSnippetTestsEngine() {
.setHttpClient(
HttpClient.builder()
.setTestPort(packageServer.port)
.addCertificates(FileTestUtils.selfSignedCertificate)
.addCertificates(FileTestUtils.selfSignedCertificatePem)
.buildLazily()
)
.setPowerAssertionsEnabled(true)
@@ -287,7 +287,7 @@ abstract class AbstractNativeLanguageSnippetTestsEngine : AbstractLanguageSnippe
add("--settings")
add("pkl:settings")
add("--ca-certificates")
add(FileTestUtils.selfSignedCertificate.toString())
add(FileTestUtils.selfSignedCertificatePem.toString())
add("--test-mode")
add("--test-port")
add(packageServer.port.toString())
@@ -1,5 +1,5 @@
/*
* Copyright © 2024 Apple Inc. and the Pkl project authors. All rights reserved.
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,9 +28,13 @@ class DummyHttpClientTest {
val client = HttpClient.dummyClient()
val request = HttpRequest.newBuilder(URI("https://example.com")).build()
assertThrows<AssertionError> { client.send(request, HttpResponse.BodyHandlers.discarding()) }
assertThrows<AssertionError> {
client.send(request, HttpResponse.BodyHandlers.discarding(), NoopChecker)
}
assertThrows<AssertionError> { client.send(request, HttpResponse.BodyHandlers.discarding()) }
assertThrows<AssertionError> {
client.send(request, HttpResponse.BodyHandlers.discarding(), NoopChecker)
}
}
@Test
@@ -1,5 +1,5 @@
/*
* Copyright © 2024 Apple Inc. and the Pkl project authors. All rights reserved.
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,17 +15,31 @@
*/
package org.pkl.core.http
import com.github.tomakehurst.wiremock.client.WireMock.get
import com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor
import com.github.tomakehurst.wiremock.client.WireMock.matching
import com.github.tomakehurst.wiremock.client.WireMock.ok
import com.github.tomakehurst.wiremock.client.WireMock.permanentRedirect
import com.github.tomakehurst.wiremock.client.WireMock.stubFor
import com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo
import com.github.tomakehurst.wiremock.client.WireMock.verify
import com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig
import com.github.tomakehurst.wiremock.junit5.WireMockExtension
import java.net.URI
import java.net.http.HttpRequest
import java.net.http.HttpResponse
import java.nio.file.Path
import java.time.Duration
import kotlin.io.path.absolutePathString
import kotlin.io.path.createFile
import kotlin.io.path.readBytes
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatCode
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertDoesNotThrow
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.RegisterExtension
import org.junit.jupiter.api.io.TempDir
import org.pkl.commons.test.FileTestUtils
import org.pkl.core.Release
@@ -68,14 +82,16 @@ class HttpClientTest {
@Test
fun `can load certificates from regular file`() {
assertDoesNotThrow {
HttpClient.builder().addCertificates(FileTestUtils.selfSignedCertificate).build()
HttpClient.builder().addCertificates(FileTestUtils.selfSignedCertificatePem).build()
}
}
@Test
fun `can load certificates from a byte array`() {
assertDoesNotThrow {
HttpClient.builder().addCertificates(FileTestUtils.selfSignedCertificate.readBytes()).build()
HttpClient.builder()
.addCertificates(FileTestUtils.selfSignedCertificatePem.readBytes())
.build()
}
}
@@ -83,8 +99,7 @@ class HttpClientTest {
fun `certificate file cannot be empty`(@TempDir tempDir: Path) {
val file = tempDir.resolve("certs.pem").createFile()
val e =
assertThrows<HttpClientInitException> { HttpClient.builder().addCertificates(file).build() }
val e = assertThrows<HttpClientException> { HttpClient.builder().addCertificates(file).build() }
assertThat(e).hasMessageContaining("empty")
}
@@ -112,10 +127,185 @@ class HttpClientTest {
client.close()
assertThrows<IllegalStateException> {
client.send(request, HttpResponse.BodyHandlers.discarding())
client.send(request, HttpResponse.BodyHandlers.discarding(), NoopChecker)
}
assertThrows<IllegalStateException> {
client.send(request, HttpResponse.BodyHandlers.discarding())
client.send(request, HttpResponse.BodyHandlers.discarding(), NoopChecker)
}
}
@Nested
inner class RedirectsTest {
// incorrect diagnostic
@Suppress("JUnitMalformedDeclaration")
@RegisterExtension
val wireMock: WireMockExtension =
with(WireMockExtension.newInstance()) {
configureStaticDsl(true)
options(
wireMockConfig().apply {
dynamicPort()
dynamicHttpsPort()
keystorePath(FileTestUtils.selfSignedCertificateP12.absolutePathString())
keystorePassword(FileTestUtils.selfSignedCertificatePassword)
keystoreType("PKCS12")
}
)
build()
}
@Test
fun `follows redirects`() {
stubFor(get(urlEqualTo("/foo.pkl")).willReturn(permanentRedirect("/bar.pkl")))
stubFor(get(urlEqualTo("/bar.pkl")).willReturn(ok("bar = 1")))
val client = HttpClient.builder().build()
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl")).build()
val response = client.send(request, HttpResponse.BodyHandlers.ofString(), NoopChecker)
assert(response.body() == "bar = 1")
verify(getRequestedFor(urlEqualTo("/foo.pkl")))
verify(getRequestedFor(urlEqualTo("/bar.pkl")))
}
@Test
fun `preserves configured headers across redirects`() {
stubFor(get(urlEqualTo("/foo.pkl")).willReturn(permanentRedirect("/bar.pkl")))
stubFor(get(urlEqualTo("/bar.pkl")).willReturn(ok("bar = 1")))
val client =
HttpClient.builder().addHeaders("**", mapOf("x-foo" to listOf("foo value"))).build()
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl")).build()
val response = client.send(request, HttpResponse.BodyHandlers.ofString(), NoopChecker)
assert(response.body() == "bar = 1")
verify(getRequestedFor(urlEqualTo("/foo.pkl")).withHeader("x-foo", matching("foo value")))
verify(getRequestedFor(urlEqualTo("/bar.pkl")).withHeader("x-foo", matching("foo value")))
}
@Test
fun `respects configured rewrites across redirects`() {
stubFor(get(urlEqualTo("/foo.pkl")).willReturn(permanentRedirect("/orig/bar.pkl")))
stubFor(get(urlEqualTo("/rewritten/bar.pkl")).willReturn(ok()))
val client =
HttpClient.builder()
.addRewrite(
URI("${wireMock.runtimeInfo.httpBaseUrl}/orig/"),
URI("${wireMock.runtimeInfo.httpBaseUrl}/rewritten/"),
)
.build()
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl")).build()
client.send(request, HttpResponse.BodyHandlers.ofString(), NoopChecker)
verify(getRequestedFor(urlEqualTo("/foo.pkl")))
verify(getRequestedFor(urlEqualTo("/rewritten/bar.pkl")))
}
@Test
fun `cannot downgrade HTTPS to HTTP`() {
stubFor(
get(urlEqualTo("/foo.pkl"))
.willReturn(permanentRedirect("${wireMock.runtimeInfo.httpBaseUrl}/bar.pkl"))
)
val client =
HttpClient.builder()
.addCertificates(FileTestUtils.selfSignedCertificatePem)
.addHeaders("**", mapOf("x-foo" to listOf("foo value")))
.build()
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpsBaseUrl}/foo.pkl")).build()
assertThatCode { client.send(request, HttpResponse.BodyHandlers.ofString(), NoopChecker) }
.hasMessageContaining("Cannot follow redirect from 'https:' URL to 'http:' URL")
}
@Test
fun `can upgrade HTTP to HTTPS`() {
stubFor(
get(urlEqualTo("/foo.pkl"))
.willReturn(permanentRedirect("${wireMock.runtimeInfo.httpsBaseUrl}/bar.pkl"))
)
stubFor(get(urlEqualTo("/bar.pkl")).willReturn(ok("hello")))
val client =
HttpClient.builder()
.addCertificates(FileTestUtils.selfSignedCertificatePem)
.addHeaders("**", mapOf("x-foo" to listOf("foo value")))
.build()
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl")).build()
val response = client.send(request, HttpResponse.BodyHandlers.ofString(), NoopChecker)
assertThat(response.body()).isEqualTo("hello")
}
@Test
fun `infinite redirects fail with VmException`() {
stubFor(get(urlEqualTo("/foo.pkl")).willReturn(permanentRedirect("/bar.pkl")))
stubFor(get(urlEqualTo("/bar.pkl")).willReturn(permanentRedirect("/foo.pkl")))
val client = HttpClient.builder().build()
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl")).build()
assertThatCode { client.send(request, HttpResponse.BodyHandlers.ofString(), NoopChecker) }
.hasMessageContaining("Too many redirects")
verify(getRequestedFor(urlEqualTo("/foo.pkl")))
verify(getRequestedFor(urlEqualTo("/bar.pkl")))
}
@Test
fun `invalid redirect URI fails with VmException`() {
stubFor(get(urlEqualTo("/foo.pkl")).willReturn(permanentRedirect("http://not a valid url/")))
val client = HttpClient.builder().build()
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl")).build()
assertThatCode { client.send(request, HttpResponse.BodyHandlers.ofString(), NoopChecker) }
.hasMessageContaining(
"""
Cannot follow HTTP redirect because the response Location header has a malformed URI.
"""
.trimIndent()
)
verify(getRequestedFor(urlEqualTo("/foo.pkl")))
}
@Test
fun `checks each URL before making a request`() {
stubFor(get(urlEqualTo("/foo.pkl")).willReturn(permanentRedirect("/bar.pkl")))
stubFor(get(urlEqualTo("/bar.pkl")).willReturn(permanentRedirect("/qux.pkl")))
stubFor(get(urlEqualTo("/qux.pkl")).willReturn(ok()))
val checkedUrls = mutableListOf<URI>()
val checker = HttpClient.HttpRequestChecker { uri -> checkedUrls.add(uri) }
val client = HttpClient.builder().build()
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl")).build()
client.send(request, HttpResponse.BodyHandlers.ofString(), checker)
assertThat(checkedUrls).hasSize(3)
assertThat(checkedUrls)
.usingRecursiveComparison()
.isEqualTo(
listOf(
URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl"),
URI("${wireMock.runtimeInfo.httpBaseUrl}/bar.pkl"),
URI("${wireMock.runtimeInfo.httpBaseUrl}/qux.pkl"),
)
)
}
@Test
fun `redirects only carry their specifically configured headers`() {
stubFor(get(urlEqualTo("/foo.pkl")).willReturn(permanentRedirect("/bar.pkl")))
stubFor(get(urlEqualTo("/bar.pkl")).willReturn(ok()))
val request =
HttpRequest.newBuilder(URI("${wireMock.runtimeInfo.httpBaseUrl}/foo.pkl")).build()
val client =
with(HttpClient.builder()) {
addHeaders("**/foo.pkl", mapOf("x-foo" to listOf("foo value")))
addHeaders("**/bar.pkl", mapOf("x-bar" to listOf("bar value")))
build()
}
client.send(request, HttpResponse.BodyHandlers.discarding(), NoopChecker)
verify(getRequestedFor(urlEqualTo("/foo.pkl")).withHeader("x-foo", matching("foo value")))
verify(getRequestedFor(urlEqualTo("/bar.pkl")).withoutHeader("x-foo"))
verify(getRequestedFor(urlEqualTo("/bar.pkl")).withHeader("x-bar", matching("bar value")))
}
}
}
@@ -1,5 +1,5 @@
/*
* Copyright © 2024 Apple Inc. and the Pkl project authors. All rights reserved.
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -33,7 +33,9 @@ class LazyHttpClientTest {
val client = HttpClient.builder().addCertificates(certFile).buildLazily()
val request = HttpRequest.newBuilder(URI("https://example.com")).build()
assertThrows<HttpClientInitException> { client.send(request, BodyHandlers.discarding()) }
assertThrows<HttpClientException> {
client.send(request, BodyHandlers.discarding(), NoopChecker)
}
}
@Test
@@ -1,5 +1,5 @@
/*
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ class RequestCapturingClient : HttpClient {
override fun <T : Any> send(
request: HttpRequest,
responseBodyHandler: HttpResponse.BodyHandler<T>,
httpRequestChecker: HttpClient.HttpRequestChecker,
): HttpResponse<T> {
this.request = request
return FakeHttpResponse()
@@ -45,7 +45,7 @@ class RequestRewritingClientTest {
@Test
fun `fills in missing User-Agent header`() {
client.send(exampleRequest, BodyHandlers.discarding())
client.send(exampleRequest, BodyHandlers.discarding(), NoopChecker)
assertThatList(captured.request.headers().allValues("User-Agent")).containsOnly("Pkl")
}
@@ -61,7 +61,7 @@ class RequestRewritingClientTest {
mapOf(URI("https://foo/") to URI("https://bar/")),
mapOf(IoUtils.doubleStarGlob to mapOf("User-Agent" to listOf("My-User-Agent"))),
)
client.send(exampleRequest, BodyHandlers.discarding())
client.send(exampleRequest, BodyHandlers.discarding(), NoopChecker)
assertThatList(captured.request.headers().allValues("User-Agent")).containsOnly("My-User-Agent")
}
@@ -73,14 +73,14 @@ class RequestRewritingClientTest {
.header("User-Agent", "Agent 2")
.build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThatList(captured.request.headers().allValues("User-Agent")).containsOnly("Pkl")
}
@Test
fun `fills in missing request timeout`() {
client.send(exampleRequest, BodyHandlers.discarding())
client.send(exampleRequest, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.timeout()).hasValue(Duration.ofSeconds(42))
}
@@ -89,14 +89,14 @@ class RequestRewritingClientTest {
fun `leaves existing request timeout intact`() {
val request = HttpRequest.newBuilder(exampleUri).timeout(Duration.ofMinutes(33)).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.timeout()).hasValue(Duration.ofMinutes(33))
}
@Test
fun `fills in missing HTTP version`() {
client.send(exampleRequest, BodyHandlers.discarding())
client.send(exampleRequest, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.version()).hasValue(JdkHttpClient.Version.HTTP_2)
}
@@ -105,7 +105,7 @@ class RequestRewritingClientTest {
fun `leaves existing HTTP version intact`() {
val request = HttpRequest.newBuilder(exampleUri).version(JdkHttpClient.Version.HTTP_1_1).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.version()).hasValue(JdkHttpClient.Version.HTTP_1_1)
}
@@ -114,7 +114,7 @@ class RequestRewritingClientTest {
fun `leaves default method intact`() {
val request = HttpRequest.newBuilder(exampleUri).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.method()).isEqualTo("GET")
}
@@ -123,7 +123,7 @@ class RequestRewritingClientTest {
fun `leaves explicit method intact`() {
val request = HttpRequest.newBuilder(exampleUri).DELETE().build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.method()).isEqualTo("DELETE")
}
@@ -133,7 +133,7 @@ class RequestRewritingClientTest {
val publisher = BodyPublishers.ofString("body")
val request = HttpRequest.newBuilder(exampleUri).PUT(publisher).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.bodyPublisher().get()).isSameAs(publisher)
}
@@ -145,7 +145,7 @@ class RequestRewritingClientTest {
RequestRewritingClient("Pkl", Duration.ofSeconds(42), 5000, captured, mapOf(), mapOf())
val request = HttpRequest.newBuilder(URI("https://example.com:0")).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.uri().port).isEqualTo(5000)
}
@@ -154,7 +154,7 @@ class RequestRewritingClientTest {
fun `leaves port 0 intact if no test port is set`() {
val request = HttpRequest.newBuilder(URI("https://example.com:0")).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.uri().port).isEqualTo(0)
}
@@ -344,7 +344,7 @@ class RequestRewritingClientTest {
val captured = RequestCapturingClient()
val client = RequestRewritingClient("Pkl", Duration.ofSeconds(42), -1, captured, rules, mapOf())
val request = HttpRequest.newBuilder(URI(uri)).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
return captured.request.uri().toString()
}
@@ -366,7 +366,7 @@ class RequestRewritingClientTest {
)
val request = HttpRequest.newBuilder(URI("https://example.com/foo/bar")).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThatList(captured.request.headers().allValues("x-one")).containsExactly("one")
assertThatList(captured.request.headers().allValues("x-two")).containsExactly("two-a", "two-b")
@@ -389,7 +389,7 @@ class RequestRewritingClientTest {
)
val request = HttpRequest.newBuilder(URI("https://example.com/foo/bar")).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThat(captured.request.headers().firstValue("x-foo")).isEmpty
assertThat(captured.request.headers().firstValue("x-bar")).isEmpty
@@ -413,7 +413,7 @@ class RequestRewritingClientTest {
val request =
HttpRequest.newBuilder(URI("https://example.com/foo/bar")).header("x-foo", "request").build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThatList(captured.request.headers().allValues("x-foo"))
.containsExactly("request", "rule-a", "rule-b")
@@ -436,7 +436,7 @@ class RequestRewritingClientTest {
)
val request = HttpRequest.newBuilder(URI("https://example.com/foo/bar")).build()
client.send(request, BodyHandlers.discarding())
client.send(request, BodyHandlers.discarding(), NoopChecker)
assertThatList(captured.request.headers().allValues("user-agent"))
.containsExactly("My User Agent")
@@ -28,3 +28,7 @@ fun HttpClient.getConfiguredSettings(): HttpSettings {
val requestRewritingClient = this.orCreateClient as RequestRewritingClient
return HttpSettings(requestRewritingClient.headers, requestRewritingClient.rewritesMap)
}
object NoopChecker : HttpClient.HttpRequestChecker {
override fun check(uri: URI) {}
}
@@ -49,7 +49,7 @@ class PackageResolversTest {
val httpClient: HttpClient by lazy {
HttpClient.builder()
.addCertificates(FileTestUtils.selfSignedCertificate)
.addCertificates(FileTestUtils.selfSignedCertificatePem)
.setTestPort(packageServer.port)
.build()
}
@@ -1,5 +1,5 @@
/*
* Copyright © 2024 Apple Inc. and the Pkl project authors. All rights reserved.
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -41,7 +41,7 @@ class ProjectDependenciesResolverTest {
val httpClient: HttpClient by lazy {
HttpClient.builder()
.addCertificates(FileTestUtils.selfSignedCertificate)
.addCertificates(FileTestUtils.selfSignedCertificatePem)
.setTestPort(packageServer.port)
.build()
}
@@ -227,7 +227,7 @@ class ProjectTest {
val project = Project.loadFromPath(projectDir.resolve("PklProject"))
val httpClient =
HttpClient.builder()
.addCertificates(FileTestUtils.selfSignedCertificate)
.addCertificates(FileTestUtils.selfSignedCertificatePem)
.setTestPort(server.port)
.build()
val evaluator =