Eagerly check function arguments when called from inside iterable (#778)

This mitigates an issue where lazy mappings and listings widen an existing bug.

This is a follow-up to https://github.com/apple/pkl/pull/752.
This commit is contained in:
Daniel Chao
2024-11-05 09:05:09 -08:00
committed by GitHub
parent 6d161ce1d4
commit b402463f3c
22 changed files with 215 additions and 49 deletions
@@ -1986,6 +1986,7 @@ public final class AstBuilder extends AbstractAstBuilder<Object> {
visitArgumentList(argCtx), visitArgumentList(argCtx),
MemberLookupMode.EXPLICIT_RECEIVER, MemberLookupMode.EXPLICIT_RECEIVER,
needsConst, needsConst,
symbolTable.getCurrentScope().isVisitingIterable(),
PropagateNullReceiverNodeGen.create(unavailableSourceSection(), receiver), PropagateNullReceiverNodeGen.create(unavailableSourceSection(), receiver),
GetClassNodeGen.create(null))); GetClassNodeGen.create(null)));
} }
@@ -1998,6 +1999,7 @@ public final class AstBuilder extends AbstractAstBuilder<Object> {
visitArgumentList(argCtx), visitArgumentList(argCtx),
MemberLookupMode.EXPLICIT_RECEIVER, MemberLookupMode.EXPLICIT_RECEIVER,
needsConst, needsConst,
symbolTable.getCurrentScope().isVisitingIterable(),
receiver, receiver,
GetClassNodeGen.create(null)); GetClassNodeGen.create(null));
} }
@@ -2072,7 +2074,11 @@ public final class AstBuilder extends AbstractAstBuilder<Object> {
} }
return InvokeSuperMethodNodeGen.create( return InvokeSuperMethodNodeGen.create(
sourceSection, memberName, visitArgumentList(argCtx), needsConst); sourceSection,
memberName,
symbolTable.getCurrentScope().isVisitingIterable(),
visitArgumentList(argCtx),
needsConst);
} }
// superproperty call // superproperty call
@@ -2130,7 +2136,8 @@ public final class AstBuilder extends AbstractAstBuilder<Object> {
isBaseModule, isBaseModule,
scope.isCustomThisScope(), scope.isCustomThisScope(),
scope.getConstLevel(), scope.getConstLevel(),
scope.getConstDepth()); scope.getConstDepth(),
scope.isVisitingIterable());
} }
@Override @Override
@@ -71,6 +71,6 @@ public final class LetExprNode extends ExpressionNode {
var value = valueNode.executeGeneric(frame); var value = valueNode.executeGeneric(frame);
return callNode.call(function.getThisValue(), function, value); return callNode.call(function.getThisValue(), function, false, value);
} }
} }
@@ -184,7 +184,8 @@ public final class AmendFunctionNode extends PklNode {
var arguments = new Object[frameArguments.length]; var arguments = new Object[frameArguments.length];
arguments[0] = functionToAmend.getThisValue(); arguments[0] = functionToAmend.getThisValue();
arguments[1] = functionToAmend; arguments[1] = functionToAmend;
System.arraycopy(frameArguments, 2, arguments, 2, frameArguments.length - 2); arguments[2] = false;
System.arraycopy(frameArguments, 3, arguments, 3, frameArguments.length - 3);
var valueToAmend = callNode.call(functionToAmend.getCallTarget(), arguments); var valueToAmend = callNode.call(functionToAmend.getCallTarget(), arguments);
if (!(valueToAmend instanceof VmFunction newFunctionToAmend)) { if (!(valueToAmend instanceof VmFunction newFunctionToAmend)) {
@@ -28,6 +28,7 @@ public final class InvokeMethodDirectNode extends ExpressionNode {
private final VmObjectLike owner; private final VmObjectLike owner;
@Child private ExpressionNode receiverNode; @Child private ExpressionNode receiverNode;
@Children private final ExpressionNode[] argumentNodes; @Children private final ExpressionNode[] argumentNodes;
private final boolean isInIterable;
@Child private DirectCallNode callNode; @Child private DirectCallNode callNode;
@@ -35,12 +36,14 @@ public final class InvokeMethodDirectNode extends ExpressionNode {
SourceSection sourceSection, SourceSection sourceSection,
ClassMethod method, ClassMethod method,
ExpressionNode receiverNode, ExpressionNode receiverNode,
ExpressionNode[] argumentNodes) { ExpressionNode[] argumentNodes,
boolean isInIterable) {
super(sourceSection); super(sourceSection);
this.owner = method.getOwner(); this.owner = method.getOwner();
this.receiverNode = receiverNode; this.receiverNode = receiverNode;
this.argumentNodes = argumentNodes; this.argumentNodes = argumentNodes;
this.isInIterable = isInIterable;
callNode = DirectCallNode.create(method.getCallTarget(sourceSection)); callNode = DirectCallNode.create(method.getCallTarget(sourceSection));
} }
@@ -48,11 +51,12 @@ public final class InvokeMethodDirectNode extends ExpressionNode {
@Override @Override
@ExplodeLoop @ExplodeLoop
public Object executeGeneric(VirtualFrame frame) { public Object executeGeneric(VirtualFrame frame) {
var args = new Object[2 + argumentNodes.length]; var args = new Object[3 + argumentNodes.length];
args[0] = receiverNode.executeGeneric(frame); args[0] = receiverNode.executeGeneric(frame);
args[1] = owner; args[1] = owner;
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) { for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame); args[3 + i] = argumentNodes[i].executeGeneric(frame);
} }
return callNode.call(args); return callNode.call(args);
@@ -33,29 +33,33 @@ public final class InvokeMethodLexicalNode extends ExpressionNode {
private final int levelsUp; private final int levelsUp;
@Child private DirectCallNode callNode; @Child private DirectCallNode callNode;
private final boolean isInIterable;
InvokeMethodLexicalNode( InvokeMethodLexicalNode(
SourceSection sourceSection, SourceSection sourceSection,
CallTarget callTarget, CallTarget callTarget,
int levelsUp, int levelsUp,
ExpressionNode[] argumentNodes) { ExpressionNode[] argumentNodes,
boolean isInIterable) {
super(sourceSection); super(sourceSection);
this.levelsUp = levelsUp; this.levelsUp = levelsUp;
this.argumentNodes = argumentNodes; this.argumentNodes = argumentNodes;
callNode = DirectCallNode.create(callTarget); callNode = DirectCallNode.create(callTarget);
this.isInIterable = isInIterable;
} }
@Override @Override
@ExplodeLoop @ExplodeLoop
public Object executeGeneric(VirtualFrame frame) { public Object executeGeneric(VirtualFrame frame) {
var args = new Object[2 + argumentNodes.length]; var args = new Object[3 + argumentNodes.length];
var enclosingFrame = getEnclosingFrame(frame); var enclosingFrame = getEnclosingFrame(frame);
args[0] = VmUtils.getReceiver(enclosingFrame); args[0] = VmUtils.getReceiver(enclosingFrame);
args[1] = VmUtils.getOwner(enclosingFrame); args[1] = VmUtils.getOwner(enclosingFrame);
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) { for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame); args[3 + i] = argumentNodes[i].executeGeneric(frame);
} }
return callNode.call(args); return callNode.call(args);
@@ -36,6 +36,7 @@ import org.pkl.core.runtime.VmClass;
import org.pkl.core.runtime.VmFunction; import org.pkl.core.runtime.VmFunction;
/** A virtual method call. */ /** A virtual method call. */
@SuppressWarnings("DuplicatedCode")
@ImportStatic(Identifier.class) @ImportStatic(Identifier.class)
@NodeChild(value = "receiverNode", type = ExpressionNode.class) @NodeChild(value = "receiverNode", type = ExpressionNode.class)
@NodeChild(value = "receiverClassNode", type = GetClassNode.class, executeWith = "receiverNode") @NodeChild(value = "receiverClassNode", type = GetClassNode.class, executeWith = "receiverNode")
@@ -44,27 +45,31 @@ public abstract class InvokeMethodVirtualNode extends ExpressionNode {
@Children private final ExpressionNode[] argumentNodes; @Children private final ExpressionNode[] argumentNodes;
private final MemberLookupMode lookupMode; private final MemberLookupMode lookupMode;
private final boolean needsConst; private final boolean needsConst;
private final boolean isInIterable;
protected InvokeMethodVirtualNode( protected InvokeMethodVirtualNode(
SourceSection sourceSection, SourceSection sourceSection,
Identifier methodName, Identifier methodName,
ExpressionNode[] argumentNodes, ExpressionNode[] argumentNodes,
MemberLookupMode lookupMode, MemberLookupMode lookupMode,
boolean needsConst) { boolean needsConst,
boolean isInIterable) {
super(sourceSection); super(sourceSection);
this.methodName = methodName; this.methodName = methodName;
this.argumentNodes = argumentNodes; this.argumentNodes = argumentNodes;
this.lookupMode = lookupMode; this.lookupMode = lookupMode;
this.needsConst = needsConst; this.needsConst = needsConst;
this.isInIterable = isInIterable;
} }
protected InvokeMethodVirtualNode( protected InvokeMethodVirtualNode(
SourceSection sourceSection, SourceSection sourceSection,
Identifier methodName, Identifier methodName,
ExpressionNode[] argumentNodes, ExpressionNode[] argumentNodes,
MemberLookupMode lookupMode) { MemberLookupMode lookupMode,
this(sourceSection, methodName, argumentNodes, lookupMode, false); boolean isInIterable) {
this(sourceSection, methodName, argumentNodes, lookupMode, false, isInIterable);
} }
/** /**
@@ -84,11 +89,12 @@ public abstract class InvokeMethodVirtualNode extends ExpressionNode {
RootCallTarget cachedCallTarget, RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) { @Cached("create(cachedCallTarget)") DirectCallNode callNode) {
var args = new Object[2 + argumentNodes.length]; var args = new Object[3 + argumentNodes.length];
args[0] = receiver.getThisValue(); args[0] = receiver.getThisValue();
args[1] = receiver; args[1] = receiver;
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) { for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame); args[3 + i] = argumentNodes[i].executeGeneric(frame);
} }
return callNode.call(args); return callNode.call(args);
@@ -103,11 +109,12 @@ public abstract class InvokeMethodVirtualNode extends ExpressionNode {
@SuppressWarnings("unused") VmClass receiverClass, @SuppressWarnings("unused") VmClass receiverClass,
@Exclusive @Cached("create()") IndirectCallNode callNode) { @Exclusive @Cached("create()") IndirectCallNode callNode) {
var args = new Object[2 + argumentNodes.length]; var args = new Object[3 + argumentNodes.length];
args[0] = receiver.getThisValue(); args[0] = receiver.getThisValue();
args[1] = receiver; args[1] = receiver;
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) { for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame); args[3 + i] = argumentNodes[i].executeGeneric(frame);
} }
return callNode.call(receiver.getCallTarget(), args); return callNode.call(receiver.getCallTarget(), args);
@@ -123,11 +130,12 @@ public abstract class InvokeMethodVirtualNode extends ExpressionNode {
@Cached("resolveMethod(receiverClass)") ClassMethod method, @Cached("resolveMethod(receiverClass)") ClassMethod method,
@Cached("create(method.getCallTarget(sourceSection))") DirectCallNode callNode) { @Cached("create(method.getCallTarget(sourceSection))") DirectCallNode callNode) {
var args = new Object[2 + argumentNodes.length]; var args = new Object[3 + argumentNodes.length];
args[0] = receiver; args[0] = receiver;
args[1] = method.getOwner(); args[1] = method.getOwner();
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) { for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame); args[3 + i] = argumentNodes[i].executeGeneric(frame);
} }
return callNode.call(args); return callNode.call(args);
@@ -142,11 +150,12 @@ public abstract class InvokeMethodVirtualNode extends ExpressionNode {
@Exclusive @Cached("create()") IndirectCallNode callNode) { @Exclusive @Cached("create()") IndirectCallNode callNode) {
var method = resolveMethod(receiverClass); var method = resolveMethod(receiverClass);
var args = new Object[2 + argumentNodes.length]; var args = new Object[3 + argumentNodes.length];
args[0] = receiver; args[0] = receiver;
args[1] = method.getOwner(); args[1] = method.getOwner();
args[2] = isInIterable;
for (var i = 0; i < argumentNodes.length; i++) { for (var i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame); args[3 + i] = argumentNodes[i].executeGeneric(frame);
} }
// Deprecation should not report here (getCallTarget(sourceSection)), as this happens for each // Deprecation should not report here (getCallTarget(sourceSection)), as this happens for each
@@ -30,15 +30,18 @@ import org.pkl.core.runtime.VmUtils;
public abstract class InvokeSuperMethodNode extends ExpressionNode { public abstract class InvokeSuperMethodNode extends ExpressionNode {
private final Identifier methodName; private final Identifier methodName;
@Children private final ExpressionNode[] argumentNodes; @Children private final ExpressionNode[] argumentNodes;
private final boolean isInIterable;
private final boolean needsConst; private final boolean needsConst;
protected InvokeSuperMethodNode( protected InvokeSuperMethodNode(
SourceSection sourceSection, SourceSection sourceSection,
Identifier methodName, Identifier methodName,
boolean isInIterable,
ExpressionNode[] argumentNodes, ExpressionNode[] argumentNodes,
boolean needsConst) { boolean needsConst) {
super(sourceSection); super(sourceSection);
this.isInIterable = isInIterable;
this.needsConst = needsConst; this.needsConst = needsConst;
assert !methodName.isLocalMethod(); assert !methodName.isLocalMethod();
@@ -54,11 +57,12 @@ public abstract class InvokeSuperMethodNode extends ExpressionNode {
@Cached(value = "findSupermethod(frame)", neverDefault = true) ClassMethod supermethod, @Cached(value = "findSupermethod(frame)", neverDefault = true) ClassMethod supermethod,
@Cached("create(supermethod.getCallTarget(sourceSection))") DirectCallNode callNode) { @Cached("create(supermethod.getCallTarget(sourceSection))") DirectCallNode callNode) {
var args = new Object[2 + argumentNodes.length]; var args = new Object[3 + argumentNodes.length];
args[0] = VmUtils.getReceiverOrNull(frame); args[0] = VmUtils.getReceiverOrNull(frame);
args[1] = supermethod.getOwner(); args[1] = supermethod.getOwner();
args[2] = isInIterable;
for (int i = 0; i < argumentNodes.length; i++) { for (int i = 0; i < argumentNodes.length; i++) {
args[2 + i] = argumentNodes[i].executeGeneric(frame); args[3 + i] = argumentNodes[i].executeGeneric(frame);
} }
return callNode.call(args); return callNode.call(args);
@@ -50,6 +50,7 @@ public final class ResolveMethodNode extends ExpressionNode {
private final boolean isCustomThisScope; private final boolean isCustomThisScope;
private final ConstLevel constLevel; private final ConstLevel constLevel;
private final int constDepth; private final int constDepth;
private final boolean isInIterable;
public ResolveMethodNode( public ResolveMethodNode(
SourceSection sourceSection, SourceSection sourceSection,
@@ -58,7 +59,8 @@ public final class ResolveMethodNode extends ExpressionNode {
boolean isBaseModule, boolean isBaseModule,
boolean isCustomThisScope, boolean isCustomThisScope,
ConstLevel constLevel, ConstLevel constLevel,
int constDepth) { int constDepth,
boolean isInIterable) {
super(sourceSection); super(sourceSection);
@@ -68,6 +70,7 @@ public final class ResolveMethodNode extends ExpressionNode {
this.isCustomThisScope = isCustomThisScope; this.isCustomThisScope = isCustomThisScope;
this.constLevel = constLevel; this.constLevel = constLevel;
this.constDepth = constDepth; this.constDepth = constDepth;
this.isInIterable = isInIterable;
} }
@Override @Override
@@ -91,7 +94,11 @@ public final class ResolveMethodNode extends ExpressionNode {
assert localMethod.isLocal(); assert localMethod.isLocal();
checkConst(currOwner, localMethod, levelsUp); checkConst(currOwner, localMethod, levelsUp);
return new InvokeMethodLexicalNode( return new InvokeMethodLexicalNode(
sourceSection, localMethod.getCallTarget(sourceSection), levelsUp, argumentNodes); sourceSection,
localMethod.getCallTarget(sourceSection),
levelsUp,
argumentNodes,
isInIterable);
} }
var method = currOwner.getVmClass().getDeclaredMethod(methodName); var method = currOwner.getVmClass().getDeclaredMethod(methodName);
if (method != null) { if (method != null) {
@@ -99,7 +106,11 @@ public final class ResolveMethodNode extends ExpressionNode {
checkConst(currOwner, method, levelsUp); checkConst(currOwner, method, levelsUp);
if (method.getDeclaringClass().isClosed()) { if (method.getDeclaringClass().isClosed()) {
return new InvokeMethodLexicalNode( return new InvokeMethodLexicalNode(
sourceSection, method.getCallTarget(sourceSection), levelsUp, argumentNodes); sourceSection,
method.getCallTarget(sourceSection),
levelsUp,
argumentNodes,
isInIterable);
} }
//noinspection ConstantConditions //noinspection ConstantConditions
@@ -108,6 +119,7 @@ public final class ResolveMethodNode extends ExpressionNode {
methodName, methodName,
argumentNodes, argumentNodes,
MemberLookupMode.IMPLICIT_LEXICAL, MemberLookupMode.IMPLICIT_LEXICAL,
isInIterable,
levelsUp == 0 ? new GetReceiverNode() : new GetEnclosingReceiverNode(levelsUp), levelsUp == 0 ? new GetReceiverNode() : new GetEnclosingReceiverNode(levelsUp),
GetClassNodeGen.create(null)); GetClassNodeGen.create(null));
} }
@@ -122,7 +134,7 @@ public final class ResolveMethodNode extends ExpressionNode {
(CallTarget) localMethod.getCallTarget().call(currOwner, currOwner); (CallTarget) localMethod.getCallTarget().call(currOwner, currOwner);
return new InvokeMethodLexicalNode( return new InvokeMethodLexicalNode(
sourceSection, methodCallTarget, levelsUp, argumentNodes); sourceSection, methodCallTarget, levelsUp, argumentNodes, isInIterable);
} }
} }
@@ -138,7 +150,7 @@ public final class ResolveMethodNode extends ExpressionNode {
if (method != null) { if (method != null) {
assert !method.isLocal(); assert !method.isLocal();
return new InvokeMethodDirectNode( return new InvokeMethodDirectNode(
sourceSection, method, new ConstantValueNode(baseModule), argumentNodes); sourceSection, method, new ConstantValueNode(baseModule), argumentNodes, isInIterable);
} }
} }
@@ -158,6 +170,7 @@ public final class ResolveMethodNode extends ExpressionNode {
argumentNodes, argumentNodes,
MemberLookupMode.IMPLICIT_THIS, MemberLookupMode.IMPLICIT_THIS,
needsConst, needsConst,
isInIterable,
VmUtils.createThisNode(VmUtils.unavailableSourceSection(), isCustomThisScope), VmUtils.createThisNode(VmUtils.unavailableSourceSection(), isCustomThisScope),
GetClassNodeGen.create(null)); GetClassNodeGen.create(null));
} }
@@ -80,6 +80,7 @@ public abstract class ToStringNode extends UnaryExpressionNode {
Identifier.TO_STRING, Identifier.TO_STRING,
new ExpressionNode[] {}, new ExpressionNode[] {},
MemberLookupMode.EXPLICIT_RECEIVER, MemberLookupMode.EXPLICIT_RECEIVER,
false,
null, null,
null); null);
} }
@@ -33,12 +33,12 @@ public abstract class ApplyVmFunction0Node extends PklNode {
RootCallTarget cachedCallTarget, RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) { @Cached("create(cachedCallTarget)") DirectCallNode callNode) {
return callNode.call(function.getThisValue(), function); return callNode.call(function.getThisValue(), function, false);
} }
@Specialization(replaces = "evalDirect") @Specialization(replaces = "evalDirect")
protected Object eval(VmFunction function, @Cached("create()") IndirectCallNode callNode) { protected Object eval(VmFunction function, @Cached("create()") IndirectCallNode callNode) {
return callNode.call(function.getCallTarget(), function.getThisValue(), function); return callNode.call(function.getCallTarget(), function.getThisValue(), function, false);
} }
} }
@@ -77,13 +77,13 @@ public abstract class ApplyVmFunction1Node extends ExpressionNode {
RootCallTarget cachedCallTarget, RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) { @Cached("create(cachedCallTarget)") DirectCallNode callNode) {
return callNode.call(function.getThisValue(), function, arg1); return callNode.call(function.getThisValue(), function, false, arg1);
} }
@Specialization(replaces = "evalDirect") @Specialization(replaces = "evalDirect")
protected Object eval( protected Object eval(
VmFunction function, Object arg1, @Cached("create()") IndirectCallNode callNode) { VmFunction function, Object arg1, @Cached("create()") IndirectCallNode callNode) {
return callNode.call(function.getCallTarget(), function.getThisValue(), function, arg1); return callNode.call(function.getCallTarget(), function.getThisValue(), function, false, arg1);
} }
} }
@@ -76,7 +76,7 @@ public abstract class ApplyVmFunction2Node extends PklNode {
RootCallTarget cachedCallTarget, RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) { @Cached("create(cachedCallTarget)") DirectCallNode callNode) {
return callNode.call(function.getThisValue(), function, arg1, arg2); return callNode.call(function.getThisValue(), function, false, arg1, arg2);
} }
@Specialization(replaces = "evalDirect") @Specialization(replaces = "evalDirect")
@@ -86,6 +86,7 @@ public abstract class ApplyVmFunction2Node extends PklNode {
Object arg2, Object arg2,
@Cached("create()") IndirectCallNode callNode) { @Cached("create()") IndirectCallNode callNode) {
return callNode.call(function.getCallTarget(), function.getThisValue(), function, arg1, arg2); return callNode.call(
function.getCallTarget(), function.getThisValue(), function, false, arg1, arg2);
} }
} }
@@ -36,7 +36,7 @@ public abstract class ApplyVmFunction3Node extends PklNode {
RootCallTarget cachedCallTarget, RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) { @Cached("create(cachedCallTarget)") DirectCallNode callNode) {
return callNode.call(function.getThisValue(), function, arg1, arg2, arg3); return callNode.call(function.getThisValue(), function, false, arg1, arg2, arg3);
} }
@Specialization(replaces = "evalDirect") @Specialization(replaces = "evalDirect")
@@ -48,6 +48,6 @@ public abstract class ApplyVmFunction3Node extends PklNode {
@Cached("create()") IndirectCallNode callNode) { @Cached("create()") IndirectCallNode callNode) {
return callNode.call( return callNode.call(
function.getCallTarget(), function.getThisValue(), function, arg1, arg2, arg3); function.getCallTarget(), function.getThisValue(), function, false, arg1, arg2, arg3);
} }
} }
@@ -38,7 +38,7 @@ public abstract class ApplyVmFunction4Node extends PklNode {
RootCallTarget cachedCallTarget, RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) { @Cached("create(cachedCallTarget)") DirectCallNode callNode) {
return callNode.call(function.getThisValue(), function, arg1, arg2, arg3, arg4); return callNode.call(function.getThisValue(), function, false, arg1, arg2, arg3, arg4);
} }
@Specialization(replaces = "evalDirect") @Specialization(replaces = "evalDirect")
@@ -51,6 +51,6 @@ public abstract class ApplyVmFunction4Node extends PklNode {
@Cached("create()") IndirectCallNode callNode) { @Cached("create()") IndirectCallNode callNode) {
return callNode.call( return callNode.call(
function.getCallTarget(), function.getThisValue(), function, arg1, arg2, arg3, arg4); function.getCallTarget(), function.getThisValue(), function, false, arg1, arg2, arg3, arg4);
} }
} }
@@ -39,7 +39,7 @@ public abstract class ApplyVmFunction5Node extends PklNode {
RootCallTarget cachedCallTarget, RootCallTarget cachedCallTarget,
@Cached("create(cachedCallTarget)") DirectCallNode callNode) { @Cached("create(cachedCallTarget)") DirectCallNode callNode) {
return callNode.call(function.getThisValue(), function, arg1, arg2, arg3, arg4, arg5); return callNode.call(function.getThisValue(), function, false, arg1, arg2, arg3, arg4, arg5);
} }
@Specialization(replaces = "evalDirect") @Specialization(replaces = "evalDirect")
@@ -53,6 +53,14 @@ public abstract class ApplyVmFunction5Node extends PklNode {
@Cached("create()") IndirectCallNode callNode) { @Cached("create()") IndirectCallNode callNode) {
return callNode.call( return callNode.call(
function.getCallTarget(), function.getThisValue(), function, arg1, arg2, arg3, arg4, arg5); function.getCallTarget(),
function.getThisValue(),
function,
false,
arg1,
arg2,
arg3,
arg4,
arg5);
} }
} }
@@ -44,7 +44,11 @@ public final class FunctionNode extends RegularMemberNode {
// For VmObject receivers, the owner is the same as or an ancestor of the receiver. // For VmObject receivers, the owner is the same as or an ancestor of the receiver.
// For other receivers, the owner is the prototype of the receiver's class. // For other receivers, the owner is the prototype of the receiver's class.
// The chain of enclosing owners forms a function/property's lexical scope. // The chain of enclosing owners forms a function/property's lexical scope.
private static final int IMPLICIT_PARAM_COUNT = 2; //
// For function calls only, a third implicit argument is passed; whether the call came from within
// an iterable node or not.
// This is a mitigation for an existing bug (https://github.com/apple/pkl/issues/741).
private static final int IMPLICIT_PARAM_COUNT = 3;
private final int paramCount; private final int paramCount;
private final int totalParamCount; private final int totalParamCount;
@@ -109,10 +113,15 @@ public final class FunctionNode extends RegularMemberNode {
throw wrongArgumentCount(totalArgCount - IMPLICIT_PARAM_COUNT); throw wrongArgumentCount(totalArgCount - IMPLICIT_PARAM_COUNT);
} }
var isInIterable = (boolean) frame.getArguments()[2];
try { try {
for (var i = 0; i < parameterTypeNodes.length; i++) { for (var i = 0; i < parameterTypeNodes.length; i++) {
var argument = frame.getArguments()[IMPLICIT_PARAM_COUNT + i]; var argument = frame.getArguments()[IMPLICIT_PARAM_COUNT + i];
parameterTypeNodes[i].executeAndSet(frame, argument); if (isInIterable) {
parameterTypeNodes[i].executeEagerlyAndSet(frame, argument);
} else {
parameterTypeNodes[i].executeAndSet(frame, argument);
}
} }
var result = bodyNode.executeGeneric(frame); var result = bodyNode.executeGeneric(frame);
@@ -55,16 +55,16 @@ public final class IdentityMixinNode extends PklRootNode {
@Override @Override
public Object execute(VirtualFrame frame) { public Object execute(VirtualFrame frame) {
var arguments = frame.getArguments(); var arguments = frame.getArguments();
if (arguments.length != 3) { if (arguments.length != 4) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw exceptionBuilder() throw exceptionBuilder()
.evalError("wrongFunctionArgumentCount", 1, arguments.length - 2) .evalError("wrongFunctionArgumentCount", 1, arguments.length - 3)
.withSourceSection(sourceSection) .withSourceSection(sourceSection)
.build(); .build();
} }
try { try {
var argument = arguments[2]; var argument = arguments[3];
if (argumentTypeNode != null) { if (argumentTypeNode != null) {
return argumentTypeNode.execute(frame, argument); return argumentTypeNode.execute(frame, argument);
} }
@@ -94,6 +94,14 @@ public abstract class TypeNode extends PklNode {
return execute(frame, value); return execute(frame, value);
} }
/**
* Checks if {@code value} conforms to this type.
*
* <p>If {@code value} is conforming, sets {@code slot} to {@code value}. Otherwise, throws a
* {@link VmTypeMismatchException}.
*/
public abstract Object executeEagerlyAndSet(VirtualFrame frame, Object value);
// method arguments are used when default value contains a root node // method arguments are used when default value contains a root node
public @Nullable Object createDefaultValue( public @Nullable Object createDefaultValue(
VmLanguage language, VmLanguage language,
@@ -213,6 +221,11 @@ public abstract class TypeNode extends PklNode {
frame.setLong(slot, (long) value); frame.setLong(slot, (long) value);
return value; return value;
} }
@Override
public Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
return executeAndSet(frame, value);
}
} }
public abstract static class ObjectSlotTypeNode extends FrameSlotTypeNode { public abstract static class ObjectSlotTypeNode extends FrameSlotTypeNode {
@@ -230,6 +243,13 @@ public abstract class TypeNode extends PklNode {
frame.setObject(slot, result); frame.setObject(slot, result);
return result; return result;
} }
@Override
public final Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
var result = executeEagerly(frame, value);
frame.setObject(slot, result);
return result;
}
} }
/** /**
@@ -263,6 +283,13 @@ public abstract class TypeNode extends PklNode {
writeSlotNode.executeWithValue(frame, result); writeSlotNode.executeWithValue(frame, result);
return result; return result;
} }
@Override
public Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
var result = executeEagerly(frame, value);
writeSlotNode.executeWithValue(frame, result);
return result;
}
} }
/** The `unknown` type. */ /** The `unknown` type. */
@@ -328,6 +355,11 @@ public abstract class TypeNode extends PklNode {
throw PklBugException.unreachableCode(); throw PklBugException.unreachableCode();
} }
@Override
public Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
return executeAndSet(frame, value);
}
@Override @Override
public FrameSlotKind getFrameSlotKind() { public FrameSlotKind getFrameSlotKind() {
return FrameSlotKind.Illegal; return FrameSlotKind.Illegal;
@@ -2382,6 +2414,22 @@ public abstract class TypeNode extends PklNode {
} }
} }
/** See docstring on {@link TypeAliasTypeNode#execute}. */
@Override
public Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
var prevOwner = VmUtils.getOwner(frame);
var prevReceiver = VmUtils.getReceiver(frame);
VmUtils.setOwner(frame, VmUtils.getOwner(typeAlias.getEnclosingFrame()));
VmUtils.setReceiver(frame, VmUtils.getReceiver(typeAlias.getEnclosingFrame()));
try {
return aliasedTypeNode.executeEagerlyAndSet(frame, value);
} finally {
VmUtils.setOwner(frame, prevOwner);
VmUtils.setReceiver(frame, prevReceiver);
}
}
@Override @Override
@TruffleBoundary @TruffleBoundary
public @Nullable Object createDefaultValue( public @Nullable Object createDefaultValue(
@@ -2501,6 +2549,13 @@ public abstract class TypeNode extends PklNode {
return ret; return ret;
} }
@Override
public Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
var ret = executeEagerly(frame, value);
childNode.executeEagerlyAndSet(frame, ret);
return ret;
}
@Override @Override
public @Nullable Object createDefaultValue( public @Nullable Object createDefaultValue(
VmLanguage language, SourceSection headerSection, String qualifiedName) { VmLanguage language, SourceSection headerSection, String qualifiedName) {
@@ -2648,6 +2703,11 @@ public abstract class TypeNode extends PklNode {
} }
} }
@Override
public Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
return executeAndSet(frame, value);
}
@Override @Override
public VmClass getVmClass() { public VmClass getVmClass() {
return BaseModule.getNumberClass(); return BaseModule.getNumberClass();
@@ -2716,6 +2776,11 @@ public abstract class TypeNode extends PklNode {
return value; return value;
} }
@Override
public Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
return executeAndSet(frame, value);
}
@Override @Override
public VmClass getVmClass() { public VmClass getVmClass() {
return BaseModule.getFloatClass(); return BaseModule.getFloatClass();
@@ -2756,6 +2821,11 @@ public abstract class TypeNode extends PklNode {
return value; return value;
} }
@Override
public Object executeEagerlyAndSet(VirtualFrame frame, Object value) {
return executeAndSet(frame, value);
}
@Override @Override
public VmClass getVmClass() { public VmClass getVmClass() {
return BaseModule.getBooleanClass(); return BaseModule.getBooleanClass();
@@ -55,7 +55,7 @@ public final class VmFunction extends VmObjectLike {
// if call site is a node, use ApplyVmFunction1Node.execute() or DirectCallNode.call() instead of // if call site is a node, use ApplyVmFunction1Node.execute() or DirectCallNode.call() instead of
// this method // this method
public Object apply(Object arg1) { public Object apply(Object arg1) {
return getCallTarget().call(thisValue, this, arg1); return getCallTarget().call(thisValue, this, false, arg1);
} }
public String applyString(Object arg1) { public String applyString(Object arg1) {
@@ -69,7 +69,7 @@ public final class VmFunction extends VmObjectLike {
// if call site is a node, use ApplyVmFunction2Node.execute() or DirectCallNode.call() instead of // if call site is a node, use ApplyVmFunction2Node.execute() or DirectCallNode.call() instead of
// this method // this method
public Object apply(Object arg1, Object arg2) { public Object apply(Object arg1, Object arg2) {
return getCallTarget().call(thisValue, this, arg1, arg2); return getCallTarget().call(thisValue, this, false, arg1, arg2);
} }
public VmFunction copy( public VmFunction copy(
@@ -31,11 +31,12 @@ public final class FunctionNodes {
protected Object eval(VmFunction self, VmList argList) { protected Object eval(VmFunction self, VmList argList) {
var argCount = argList.getLength(); var argCount = argList.getLength();
var args = new Object[2 + argCount]; var args = new Object[3 + argCount];
args[0] = self.getThisValue(); args[0] = self.getThisValue();
args[1] = self; args[1] = self;
args[2] = false;
var i = 2; var i = 3;
for (var arg : argList) { for (var arg : argList) {
args[i++] = arg; args[i++] = arg;
} }
@@ -31,3 +31,27 @@ res2 {
}.birds }.birds
} }
} }
res3 {
for (key, _ in Map("hello-there", 5)) {
...myself(new Listing {
new Listing {
key
}
})
}
}
res4 {
for (key, _ in Map("hello-there", 5)) {
...myself2.apply(new Listing {
new Listing {
key
}
})
}
}
function myself(l: Listing<Listing<String>>) = l
local myself2 = (l: Listing<Listing<String>>) -> l
@@ -8,3 +8,13 @@ res2 {
age = 1 age = 1
} }
} }
res3 {
new {
"hello-there"
}
}
res4 {
new {
"hello-there"
}
}