Optimize let expressions as frame slot vars (#1622)

This introduces an optimization to write let expressions to frame slots
of the current frame, rather than transforming them into lambdas. As a
result, let expressions are _much_ faster to execute.
This commit is contained in:
Daniel Chao
2026-06-01 16:53:30 -07:00
committed by GitHub
parent bfda4cc8c8
commit 7dd2bc67de
11 changed files with 191 additions and 139 deletions
@@ -54,7 +54,7 @@ import org.pkl.core.ast.builder.SymbolTable.AnnotationScope;
import org.pkl.core.ast.builder.SymbolTable.ClassScope;
import org.pkl.core.ast.builder.SymbolTable.ModuleScope;
import org.pkl.core.ast.builder.SymbolTable.ObjectScope;
import org.pkl.core.ast.builder.VariableResolution.ForGeneratorVariable;
import org.pkl.core.ast.builder.VariableResolution.ForGeneratorOrLetVariable;
import org.pkl.core.ast.builder.VariableResolution.ImplicitBaseProperty;
import org.pkl.core.ast.builder.VariableResolution.ImplicitThisProperty;
import org.pkl.core.ast.builder.VariableResolution.LexicalProperty;
@@ -67,7 +67,7 @@ import org.pkl.core.ast.expression.binary.GreaterThanNodeGen;
import org.pkl.core.ast.expression.binary.GreaterThanOrEqualNodeGen;
import org.pkl.core.ast.expression.binary.LessThanNodeGen;
import org.pkl.core.ast.expression.binary.LessThanOrEqualNodeGen;
import org.pkl.core.ast.expression.binary.LetExprNode;
import org.pkl.core.ast.expression.binary.LetExprNodeGen;
import org.pkl.core.ast.expression.binary.LogicalAndNodeGen;
import org.pkl.core.ast.expression.binary.LogicalOrNodeGen;
import org.pkl.core.ast.expression.binary.MultiplicationNodeGen;
@@ -121,7 +121,6 @@ import org.pkl.core.ast.expression.member.ReadLocalPropertyNode;
import org.pkl.core.ast.expression.member.ReadPropertyNodeGen;
import org.pkl.core.ast.expression.member.ReadSuperEntryNode;
import org.pkl.core.ast.expression.member.ReadSuperPropertyNode;
import org.pkl.core.ast.expression.primary.GetEnclosingOwnerNode;
import org.pkl.core.ast.expression.primary.GetEnclosingReceiverNode;
import org.pkl.core.ast.expression.primary.GetMemberKeyNode;
import org.pkl.core.ast.expression.primary.GetModuleNode;
@@ -683,10 +682,10 @@ public class AstBuilder extends AbstractAstBuilder<Object> {
MemberLookupMode.IMPLICIT_LEXICAL,
needsConst,
p.levelsUp() == 0 ? new GetReceiverNode() : new GetEnclosingReceiverNode(p.levelsUp()));
} else if (resolution instanceof ForGeneratorVariable p) {
} else if (resolution instanceof ForGeneratorOrLetVariable p) {
// Parameters can possibly write to frame slots actually in a frame that is one level
// higher than what we can tell at parse time. However, for generator variables always
// write to frame slots in the same frame.
// higher than what we can tell at parse time. However, let exprs and for generator variables
// always write to frame slots in the same frame.
//
// function foo(bar) = new Mixin {
// [bar] = 1 <--- actually 1 level, not 0
@@ -891,29 +890,19 @@ public class AstBuilder extends AbstractAstBuilder<Object> {
Node child = expr;
var parent = expr.parent();
var scope = symbolTable.getCurrentScope();
var levelsUp = 0;
while (parent instanceof IfExpr
|| parent instanceof TraceExpr
|| parent instanceof LetExpr letExpr && letExpr.getExpr() == child) {
if (parent instanceof LetExpr) {
assert scope != null;
scope = scope.getParent();
levelsUp += 1;
}
child = parent;
parent = parent.parent();
}
assert scope != null;
if (parent instanceof ClassProperty || parent instanceof ObjectProperty) {
inferredParentNode =
InferParentWithinPropertyNodeGen.create(
createSourceSection(expr.newSpan()),
scope.getName(),
levelsUp == 0 ? new GetOwnerNode() : new GetEnclosingOwnerNode(levelsUp));
createSourceSection(expr.newSpan()), scope.getName(), new GetOwnerNode());
} else if (parent instanceof ObjectElement
|| parent instanceof ObjectEntry objectEntry && objectEntry.getValue() == child) {
inferredParentNode =
@@ -921,7 +910,7 @@ public class AstBuilder extends AbstractAstBuilder<Object> {
ReadPropertyNodeGen.create(
createSourceSection(expr.newSpan()),
org.pkl.core.runtime.Identifier.DEFAULT,
levelsUp == 0 ? new GetReceiverNode() : new GetEnclosingReceiverNode(levelsUp)),
new GetReceiverNode()),
new GetMemberKeyNode());
} else if (parent instanceof ClassMethod || parent instanceof ObjectMethod) {
var isObjectMethod =
@@ -931,18 +920,11 @@ public class AstBuilder extends AbstractAstBuilder<Object> {
inferredParentNode =
isObjectMethod
? new InferParentWithinObjectMethodNode(
createSourceSection(expr.newSpan()),
language,
scopeName,
levelsUp == 0 ? new GetOwnerNode() : new GetEnclosingOwnerNode(levelsUp))
createSourceSection(expr.newSpan()), language, scopeName, new GetOwnerNode())
: new InferParentWithinMethodNode(
createSourceSection(expr.newSpan()),
language,
scopeName,
levelsUp == 0 ? new GetOwnerNode() : new GetEnclosingOwnerNode(levelsUp));
createSourceSection(expr.newSpan()), language, scopeName, new GetOwnerNode());
} else if (parent instanceof LetExpr letExpr && letExpr.getBindingExpr() == child) {
// TODO (unclear how to infer type now that let-expression is implemented as lambda
// invocation)
// TODO correctly infer parent, e.g. `let (x: Person = new {}) ...`
throw exceptionBuilder()
.evalError("cannotInferParent")
.withSourceSection(createSourceSection(expr.newSpan()))
@@ -1113,38 +1095,28 @@ public class AstBuilder extends AbstractAstBuilder<Object> {
public ExpressionNode visitLetExpr(LetExpr letExpr) {
var sourceSection = createSourceSection(letExpr);
var parameter = letExpr.getParameter();
var frameBuilder = new FrameDescriptorBuilder();
UnresolvedTypeNode[] typeNodes;
var bindings = new ArrayList<String>();
UnresolvedTypeNode typeNode = null;
String binding = null;
var slot = -1;
var frameDescriptorBuilder = symbolTable.getCurrentScope().frameDescriptorBuilder;
if (parameter instanceof TypedIdentifier par) {
typeNodes = new UnresolvedTypeNode[] {visitTypeAnnotation(par.getTypeAnnotation())};
frameBuilder.addSlot(
FrameSlotKind.Illegal, toIdentifier(par.getIdentifier().getValue()), null);
bindings.add(par.getIdentifier().getValue());
} else {
typeNodes = new UnresolvedTypeNode[0];
typeNode = visitTypeAnnotation(par.getTypeAnnotation());
slot =
frameDescriptorBuilder.addSlot(
FrameSlotKind.Illegal, toIdentifier(par.getIdentifier().getValue()), null);
binding = par.getIdentifier().getValue();
}
var isCustomThisScope = symbolTable.getCurrentScope().isCustomThisScope();
UnresolvedFunctionNode functionNode =
symbolTable.enterLambda(
bindings,
frameBuilder,
scope -> {
var expr = visitExpr(letExpr.getExpr());
return new UnresolvedFunctionNode(
language,
scope.buildFrameDescriptor(),
new Lambda(createSourceSection(letExpr.getExpr()), scope.getQualifiedName()),
1,
typeNodes,
null,
expr);
});
return new LetExprNode(
sourceSection, functionNode, visitExpr(letExpr.getBindingExpr()), isCustomThisScope);
var bindingExpr = visitExpr(letExpr.getBindingExpr());
var t = typeNode;
var s = slot;
return symbolTable.enterLetExpression(
binding,
slot,
scope -> {
var bodyExpr = visitExpr(letExpr.getExpr());
return LetExprNodeGen.create(
sourceSection, scope.getQualifiedName(), t, bodyExpr, s, bindingExpr);
});
}
@Override
@@ -2159,6 +2131,10 @@ public class AstBuilder extends AbstractAstBuilder<Object> {
throw PklBugException.unreachableCode();
}
public FrameDescriptor buildModuleFrameDescriptor() {
return symbolTable.getCurrentScope().buildFrameDescriptor();
}
private ResolveDeclaredTypeNode doVisitTypeName(QualifiedIdentifier ctx) {
var identifiers = ctx.getIdentifiers();
return switch (identifiers.size()) {
@@ -26,6 +26,7 @@ import org.pkl.core.ast.VmModifier;
import org.pkl.core.ast.builder.MethodResolution.ImplicitBaseMethod;
import org.pkl.core.ast.builder.MethodResolution.ImplicitThisMethod;
import org.pkl.core.ast.builder.MethodResolution.LexicalMethod;
import org.pkl.core.ast.builder.VariableResolution.ForGeneratorOrLetVariable;
import org.pkl.core.ast.builder.VariableResolution.ImplicitBaseProperty;
import org.pkl.core.ast.builder.VariableResolution.LexicalProperty;
import org.pkl.core.ast.member.ObjectMember;
@@ -142,6 +143,20 @@ public final class SymbolTable {
nodeFactory);
}
public <T> T enterLetExpression(
@Nullable String binding, int slot, Function<LetExpressionScope, T> nodeFactory) {
// flatten names of let exprs inside other let exprs for presentation purposes
var parentScope = currentScope;
while (parentScope instanceof LetExpressionScope) {
parentScope = parentScope.getParent();
}
assert parentScope != null;
var qualifiedName = parentScope.qualifiedName + "." + "<let expr>";
return doEnter(new LetExpressionScope(currentScope, binding, slot, qualifiedName), nodeFactory);
}
public <T> T enterProperty(
Identifier name, ConstLevel constLevel, Function<PropertyScope, T> nodeFactory) {
return doEnter(
@@ -339,6 +354,19 @@ public final class SymbolTable {
return curr;
}
public final Scope skipLambdaAndLetScopes() {
var curr = this;
while (curr.isLambdaScope() || curr.isLetScope()) {
curr = curr.getParent();
assert curr != null : "Lambda scope always has a parent";
}
return curr;
}
public final boolean isLetScope() {
return this instanceof LetExpressionScope;
}
public final boolean isModuleScope() {
return this instanceof ModuleScope;
}
@@ -348,7 +376,7 @@ public final class SymbolTable {
}
public final boolean isClassMemberScope() {
var effectiveScope = skipLambdaScopes();
var effectiveScope = skipLambdaAndLetScopes();
var parent = effectiveScope.parent;
if (parent == null) return false;
@@ -454,8 +482,10 @@ public final class SymbolTable {
}
var result = fun.apply(lex, levelsUp);
if (result != null) return result;
if (scope instanceof MethodScope || scope instanceof ForGeneratorScope) {
// fors and methods don't level up
if (scope instanceof MethodScope
|| scope instanceof ForGeneratorScope
|| scope instanceof LetExpressionScope) {
// fors, methods, and let exprs don't level up
continue;
}
levelsUp++;
@@ -640,6 +670,46 @@ public final class SymbolTable {
}
}
public static final class LetExpressionScope extends Scope implements LexicalScope {
private final @Nullable String binding;
private final int slot;
private static @Nullable Identifier getParentName(Scope parent) {
while (parent != null && parent.name == null) {
parent = parent.getParent();
}
return parent == null ? null : parent.name;
}
public LetExpressionScope(
Scope parent, @Nullable String binding, int slot, String qualifiedName) {
super(
parent,
getParentName(parent),
qualifiedName,
parent.getConstLevel(),
parent.frameDescriptorBuilder);
this.binding = binding;
this.slot = slot;
}
@Override
public @Nullable VariableResolution doResolveProperty(String name, int levelsUp) {
if (name.equals("_")) {
return null;
}
if (name.equals(binding)) {
return new ForGeneratorOrLetVariable(slot, levelsUp);
}
return null;
}
@Override
public @Nullable MethodResolution doResolveMethod(String name, int levelsUp) {
return null;
}
}
// A generator scope that is resolved eagerly and one level above
public static final class EagerGeneratorScope extends Scope {
private static FrameDescriptorBuilder getFrameDescriptorBuilder(Scope parent) {
@@ -679,7 +749,7 @@ public final class SymbolTable {
}
var index = frameDescriptorBuilder.findSlot(Identifier.get(name));
if (index >= 0) {
return new VariableResolution.ForGeneratorVariable(index, levelsUp);
return new ForGeneratorOrLetVariable(index, levelsUp);
}
return null;
}
@@ -31,10 +31,10 @@ public sealed interface VariableResolution {
}
}
// let, lambda, object body param
// method, lambda, object body param
record Parameter(int slot, int levelsUp) implements VariableResolution {}
record ForGeneratorVariable(int slot, int levelsUp) implements VariableResolution {}
record ForGeneratorOrLetVariable(int slot, int levelsUp) implements VariableResolution {}
// Implicit base module lookup
record ImplicitBaseProperty() implements VariableResolution {}
@@ -1,5 +1,5 @@
/*
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,60 +16,68 @@
package org.pkl.core.ast.expression.binary;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.source.SourceSection;
import org.jspecify.annotations.Nullable;
import org.pkl.core.ast.ExpressionNode;
import org.pkl.core.ast.member.FunctionNode;
import org.pkl.core.ast.member.UnresolvedFunctionNode;
import org.pkl.core.runtime.VmFunction;
import org.pkl.core.ast.type.TypeNode;
import org.pkl.core.ast.type.UnresolvedTypeNode;
import org.pkl.core.runtime.VmException;
import org.pkl.core.runtime.VmUtils;
import org.pkl.core.util.LateInit;
public final class LetExprNode extends ExpressionNode {
private @Child UnresolvedFunctionNode unresolvedFunctionNode;
private @Child ExpressionNode valueNode;
private final boolean isCustomThisScope;
@NodeChild(value = "bindingNode", type = ExpressionNode.class)
public abstract class LetExprNode extends ExpressionNode {
@CompilationFinal @LateInit private FunctionNode functionNode;
@Child @LateInit private DirectCallNode callNode;
@CompilationFinal private int customThisSlot = -1;
private final String qualifiedName;
private @Child @Nullable UnresolvedTypeNode unresolvedTypeNode;
private @Child ExpressionNode bodyNode;
private @Child @Nullable TypeNode typeNode;
private final int slot;
public LetExprNode(
protected LetExprNode(
SourceSection sourceSection,
UnresolvedFunctionNode functionNode,
ExpressionNode valueNode,
boolean isCustomThisScope) {
String qualifiedName,
@Nullable UnresolvedTypeNode unresolvedTypeNode,
ExpressionNode bodyNode,
int slot) {
super(sourceSection);
this.unresolvedFunctionNode = functionNode;
this.valueNode = valueNode;
this.isCustomThisScope = isCustomThisScope;
this.qualifiedName = qualifiedName;
this.unresolvedTypeNode = unresolvedTypeNode;
this.bodyNode = bodyNode;
this.slot = slot;
}
@Override
public Object executeGeneric(VirtualFrame frame) {
if (functionNode == null) {
private TypeNode getTypeNode(VirtualFrame frame) {
if (typeNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
functionNode = unresolvedFunctionNode.execute(frame);
callNode = insert(DirectCallNode.create(functionNode.getCallTarget()));
if (isCustomThisScope) {
// deferred until execution time s.t. nodes of inlined type aliases get the right frame slot
customThisSlot = VmUtils.findCustomThisSlot(frame);
if (unresolvedTypeNode != null) {
typeNode = unresolvedTypeNode.execute(frame);
} else {
typeNode = new TypeNode.UnknownTypeNode(VmUtils.unavailableSourceSection());
}
typeNode.initWriteSlotNode(slot);
insert(typeNode);
}
assert typeNode != null;
return typeNode;
}
var function =
new VmFunction(
frame.materialize(),
isCustomThisScope ? frame.getAuxiliarySlot(customThisSlot) : VmUtils.getReceiver(frame),
1,
functionNode,
null);
var value = valueNode.executeGeneric(frame);
return callNode.call(function.getThisValue(), function, value);
@Specialization
protected Object eval(VirtualFrame frame, Object value) {
if (slot != -1) {
getTypeNode(frame).executeAndSet(frame, value);
}
try {
return bodyNode.executeGeneric(frame);
} catch (VmException e) {
CompilerDirectives.transferToInterpreter();
e.getInsertedStackFrames()
.put(
getRootNode().getCallTarget(),
VmUtils.createStackFrame(getSourceSection(), qualifiedName));
throw e;
}
}
}
@@ -961,12 +961,16 @@ public final class VmUtils {
resolvedModule,
false);
var language = VmLanguage.get(null);
var builder = new AstBuilder(source, language, moduleInfo, moduleResolver);
var parsedExpression = parseExpressionNode(expression, source);
var builder = new AstBuilder(source, language, moduleInfo, moduleResolver);
var exprNode = builder.visitExpr(parsedExpression);
var rootNode =
new SimpleRootNode(
language, new FrameDescriptor(), exprNode.getSourceSection(), "", exprNode);
language,
builder.buildModuleFrameDescriptor(),
exprNode.getSourceSection(),
"",
exprNode);
var callNode = Truffle.getRuntime().createIndirectCallNode();
return callNode.call(rootNode.getCallTarget(), module, module);
}
@@ -48,7 +48,7 @@ public final class MicrobenchmarkNodes {
var runIterationsNode =
new RunIterationsNode(
VmLanguage.get(this),
new FrameDescriptor(),
codeMemberNode.getFrameDescriptor(),
(ExpressionNode) codeMemberNode.getBodyNode().deepCopy());
var callTarget = runIterationsNode.getCallTarget();
return runBenchmark(self, (iterations) -> callTarget.call(self, self, iterations));
@@ -10,24 +10,25 @@ res2 =
res3 =
let (x = 1)
let (y = 2)
x + y + x
let (y = 2)
x + y + x
res4 =
let (x = 1)
let (x = 2)
x + x
let (x = 2)
x + x
res5 =
let (price = 500) new {
lowestPrice = price - 100
averagePrice = new {
price
let (price = 500)
new {
lowestPrice = price - 100
averagePrice = new {
price
}
highestPrice = new {
["price"] = price + 100
}
}
highestPrice = new {
["price"] = price + 100
}
}
res6 =
let (str = "Pigeon".reverse())
@@ -51,8 +52,9 @@ local g = (a) ->
res9 = g.apply(3)
local h = let (b = 2)
(a) -> a + b
local h =
let (b = 2)
(a) -> a + b
res10 = h.apply(3)
@@ -71,9 +73,9 @@ res12 = new Lets {
res13 =
let (x = 1)
let (y = x)
let (z = y)
x + y + z
let (y = x)
let (z = y)
x + y + z
// x can't access y
res14 = test.catch(() -> let (x = y) let (y = 2) x + y)
@@ -1,4 +1,4 @@
res1 =
let (x = 1)
let (y = 2)
throw("ouch")
let (y = 2)
throw("ouch")
@@ -3,15 +3,11 @@ ouch
x | throw("ouch")
^^^^^^^^^^^^^
at letExpressionError1#res1.<function#2> (file:///$snippetsDir/input/errors/letExpressionError1.pkl)
x | let (y = 2)
^^^^^^^^^^^
at letExpressionError1#res1.<function#1> (file:///$snippetsDir/input/errors/letExpressionError1.pkl)
at letExpressionError1#res1 (file:///$snippetsDir/input/errors/letExpressionError1.pkl)
x | let (x = 1)
^^^^^^^^^^^
at letExpressionError1#res1 (file:///$snippetsDir/input/errors/letExpressionError1.pkl)
at letExpressionError1#res1.<let expr> (file:///$snippetsDir/input/errors/letExpressionError1.pkl)
xxx | renderer.renderDocument(value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -3,11 +3,11 @@ ouch
x | let (y = throw("ouch"))
^^^^^^^^^^^^^
at letExpressionError2#res1.<function#1> (file:///$snippetsDir/input/errors/letExpressionError2.pkl)
at letExpressionError2#res1 (file:///$snippetsDir/input/errors/letExpressionError2.pkl)
x | let (x = 1)
^^^^^^^^^^^
at letExpressionError2#res1 (file:///$snippetsDir/input/errors/letExpressionError2.pkl)
at letExpressionError2#res1.<let expr> (file:///$snippetsDir/input/errors/letExpressionError2.pkl)
xxx | renderer.renderDocument(value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -4,10 +4,6 @@ Value: "abc"
x | let (x: Float = "abc")
^^^^^
at letExpressionErrorTyped#res1.<function#1> (file:///$snippetsDir/input/errors/letExpressionErrorTyped.pkl)
x | let (x: Float = "abc")
^^^^^^^^^^^^^^^^^^^^^^
at letExpressionErrorTyped#res1 (file:///$snippetsDir/input/errors/letExpressionErrorTyped.pkl)
xxx | renderer.renderDocument(value)