From cd282cc08fb8547c37eda71f8f30a4bb35aa4c84 Mon Sep 17 00:00:00 2001 From: Stefan Marr Date: Fri, 4 Oct 2024 00:56:30 +0100 Subject: [PATCH] Add square super nodes This supports the basic multiplication of a variable with itself for local and non-local variables, i.e.: - `local * local` - `| l | [ l * l ]` It also supports multiplication and assigning result to local/non-local variables: - `b := l * l` Signed-off-by: Stefan Marr --- .../compiler/MethodGenerationContext.java | 20 ++- .../src/trufflesom/compiler/ParserAst.java | 18 +++ .../src/trufflesom/compiler/Variable.java | 52 ++++++ .../interpreter/nodes/LocalVariableNode.java | 6 +- .../nodes/NonLocalVariableNode.java | 6 +- .../LocalVariableReadSquareWriteNode.java | 84 ++++++++++ .../supernodes/LocalVariableSquareNode.java | 41 +++++ .../NonLocalVariableReadSquareWriteNode.java | 148 ++++++++++++++++++ .../NonLocalVariableSquareNode.java | 47 ++++++ tests/trufflesom/supernodes/SquareTests.java | 83 ++++++++++ 10 files changed, 499 insertions(+), 6 deletions(-) create mode 100644 src/trufflesom/src/trufflesom/interpreter/supernodes/LocalVariableReadSquareWriteNode.java create mode 100644 src/trufflesom/src/trufflesom/interpreter/supernodes/LocalVariableSquareNode.java create mode 100644 src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalVariableReadSquareWriteNode.java create mode 100644 src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalVariableSquareNode.java create mode 100644 tests/trufflesom/supernodes/SquareTests.java diff --git a/src/trufflesom/src/trufflesom/compiler/MethodGenerationContext.java b/src/trufflesom/src/trufflesom/compiler/MethodGenerationContext.java index bd04a9a99..09648ff3d 100644 --- a/src/trufflesom/src/trufflesom/compiler/MethodGenerationContext.java +++ b/src/trufflesom/src/trufflesom/compiler/MethodGenerationContext.java @@ -28,9 +28,6 @@ import static trufflesom.vm.SymbolTable.strBlockSelf; import static trufflesom.vm.SymbolTable.strFrameOnStack; import static trufflesom.vm.SymbolTable.strSelf; -import static trufflesom.vm.SymbolTable.symBlockSelf; -import static trufflesom.vm.SymbolTable.symFrameOnStack; -import static trufflesom.vm.SymbolTable.symSelf; import static trufflesom.vm.SymbolTable.symbolFor; import java.util.ArrayList; @@ -59,6 +56,8 @@ import trufflesom.interpreter.nodes.ReturnNonLocalNode.CatchNonLocalReturnNode; import trufflesom.interpreter.nodes.literals.BlockNode; import trufflesom.interpreter.supernodes.IntIncrementNode; +import trufflesom.interpreter.supernodes.LocalVariableSquareNode; +import trufflesom.interpreter.supernodes.NonLocalVariableSquareNode; import trufflesom.primitives.Primitives; import trufflesom.vmobjects.SClass; import trufflesom.vmobjects.SInvokable; @@ -404,7 +403,20 @@ public ExpressionNode getLocalReadNode(final Variable variable, final long coord public ExpressionNode getLocalWriteNode(final Variable variable, final ExpressionNode valExpr, final long coord) { - return variable.getWriteNode(getContextLevel(variable), valExpr, coord); + int ctxLevel = getContextLevel(variable); + + if (ctxLevel == 0) { + if (valExpr instanceof LocalVariableSquareNode l) { + return variable.getReadSquareWriteNode(ctxLevel, coord, l.getLocal(), 0); + } + } + + if (valExpr instanceof NonLocalVariableSquareNode nl) { + return variable.getReadSquareWriteNode(ctxLevel, coord, nl.getLocal(), + nl.getContextLevel()); + } + + return variable.getWriteNode(ctxLevel, valExpr, coord); } protected Local getLocal(final String varName) { diff --git a/src/trufflesom/src/trufflesom/compiler/ParserAst.java b/src/trufflesom/src/trufflesom/compiler/ParserAst.java index 26db94888..d17ceb579 100644 --- a/src/trufflesom/src/trufflesom/compiler/ParserAst.java +++ b/src/trufflesom/src/trufflesom/compiler/ParserAst.java @@ -35,7 +35,9 @@ import trufflesom.interpreter.nodes.FieldNode; import trufflesom.interpreter.nodes.FieldNode.FieldReadNode; import trufflesom.interpreter.nodes.GlobalNode; +import trufflesom.interpreter.nodes.LocalVariableNode.LocalVariableReadNode; import trufflesom.interpreter.nodes.MessageSendNode; +import trufflesom.interpreter.nodes.NonLocalVariableNode.NonLocalVariableReadNode; import trufflesom.interpreter.nodes.SequenceNode; import trufflesom.interpreter.nodes.literals.BlockNode; import trufflesom.interpreter.nodes.literals.BlockNode.BlockNodeWithContext; @@ -45,7 +47,9 @@ import trufflesom.interpreter.nodes.literals.LiteralNode; import trufflesom.interpreter.supernodes.IntIncrementNodeGen; import trufflesom.interpreter.supernodes.LocalFieldStringEqualsNode; +import trufflesom.interpreter.supernodes.LocalVariableSquareNodeGen; import trufflesom.interpreter.supernodes.NonLocalFieldStringEqualsNode; +import trufflesom.interpreter.supernodes.NonLocalVariableSquareNodeGen; import trufflesom.interpreter.supernodes.StringEqualsNodeGen; import trufflesom.primitives.Primitives; import trufflesom.vm.Globals; @@ -307,6 +311,20 @@ protected ExpressionNode binaryMessage(final MethodGenerationContext mgenc, return StringEqualsNodeGen.create(s, operand).initialize(coordWithL); } } + } else if (binSelector.equals("*")) { + if (receiver instanceof LocalVariableReadNode rcvr + && operand instanceof LocalVariableReadNode op) { + if (rcvr.isSameLocal(op)) { + return LocalVariableSquareNodeGen.create(rcvr.getLocal()).initialize(coordWithL); + } + } else if (receiver instanceof NonLocalVariableReadNode rcvr + && operand instanceof NonLocalVariableReadNode op) { + if (rcvr.isSameLocal(op)) { + assert rcvr.getContextLevel() == op.getContextLevel(); + return NonLocalVariableSquareNodeGen.create( + rcvr.getContextLevel(), rcvr.getLocal()).initialize(coordWithL); + } + } } ExpressionNode inlined = diff --git a/src/trufflesom/src/trufflesom/compiler/Variable.java b/src/trufflesom/src/trufflesom/compiler/Variable.java index 7b7b621cd..b6f1acba3 100644 --- a/src/trufflesom/src/trufflesom/compiler/Variable.java +++ b/src/trufflesom/src/trufflesom/compiler/Variable.java @@ -27,6 +27,10 @@ import trufflesom.interpreter.nodes.LocalVariableNodeFactory.LocalVariableWriteNodeGen; import trufflesom.interpreter.nodes.NonLocalVariableNodeFactory.NonLocalVariableReadNodeGen; import trufflesom.interpreter.nodes.NonLocalVariableNodeFactory.NonLocalVariableWriteNodeGen; +import trufflesom.interpreter.supernodes.LocalVariableReadSquareWriteNodeGen; +import trufflesom.interpreter.supernodes.LocalVariableSquareNodeGen; +import trufflesom.interpreter.supernodes.NonLocalVariableReadSquareWriteNodeGen; +import trufflesom.interpreter.supernodes.NonLocalVariableSquareNodeGen; import trufflesom.vm.NotYetImplementedException; import trufflesom.vmobjects.SSymbol; @@ -59,6 +63,11 @@ public String toString() { public abstract ExpressionNode getWriteNode( int contextLevel, ExpressionNode valueExpr, long coord); + public abstract ExpressionNode getSquareNode(int contextLevel, long coord); + + public abstract ExpressionNode getReadSquareWriteNode(int writeContextLevel, long coord, + Local readLocal, int readContextLevel); + protected abstract void emitPop(BytecodeMethodGenContext mgenc); protected abstract void emitPush(BytecodeMethodGenContext mgenc); @@ -136,6 +145,17 @@ public ExpressionNode getWriteNode(final int contextLevel, } } + @Override + public ExpressionNode getSquareNode(final int contextLevel, final long coord) { + throw new NotYetImplementedException(); + } + + @Override + public ExpressionNode getReadSquareWriteNode(final int writeContextLevel, final long coord, + final Local readLocal, final int readContextLevel) { + throw new NotYetImplementedException(); + } + @Override public void emitPop(final BytecodeMethodGenContext mgenc) { emitPOPARGUMENT(mgenc, (byte) index, (byte) mgenc.getContextLevel(this)); @@ -176,6 +196,25 @@ public ExpressionNode getReadNode(final int contextLevel, final long coordinate) return LocalVariableReadNodeGen.create(this).initialize(coordinate); } + @Override + public ExpressionNode getSquareNode(final int contextLevel, final long coord) { + if (contextLevel > 0) { + return NonLocalVariableSquareNodeGen.create(contextLevel, this).initialize(coord); + } + return LocalVariableSquareNodeGen.create(this).initialize(coord); + } + + @Override + public ExpressionNode getReadSquareWriteNode(final int writeContextLevel, final long coord, + final Local readLocal, final int readContextLevel) { + if (writeContextLevel > 0) { + return NonLocalVariableReadSquareWriteNodeGen.create( + writeContextLevel, this, readLocal, readContextLevel).initialize(coord); + } + assert readContextLevel == 0; + return LocalVariableReadSquareWriteNodeGen.create(this, readLocal).initialize(coord); + } + public final int getIndex() { return slotIndex; } @@ -234,6 +273,19 @@ public ExpressionNode getReadNode(final int contextLevel, final long coordinate) + "They are used directly by other nodes."); } + @Override + public ExpressionNode getSquareNode(final int contextLevel, final long coord) { + throw new UnsupportedOperationException( + "There shouldn't be any language-level square nodes for internal slots. "); + } + + @Override + public ExpressionNode getReadSquareWriteNode(final int readContextLevel, final long coord, + final Local readLocal, final int writeContextLevel) { + throw new UnsupportedOperationException( + "There shouldn't be any language-level square nodes for internal slots. "); + } + @Override public Internal split() { return new Internal(name, coord, slotIndex); diff --git a/src/trufflesom/src/trufflesom/interpreter/nodes/LocalVariableNode.java b/src/trufflesom/src/trufflesom/interpreter/nodes/LocalVariableNode.java index dc8e26b7d..d89167155 100644 --- a/src/trufflesom/src/trufflesom/interpreter/nodes/LocalVariableNode.java +++ b/src/trufflesom/src/trufflesom/interpreter/nodes/LocalVariableNode.java @@ -22,7 +22,7 @@ public abstract class LocalVariableNode extends NoPreEvalExprNode // TODO: We currently assume that there is a 1:1 mapping between lexical contexts // and frame descriptors, which is apparently not strictly true anymore in Truffle 1.0.0. // Generally, we also need to revise everything in this area and address issue SOMns#240. - private LocalVariableNode(final Local local) { + protected LocalVariableNode(final Local local) { this.local = local; this.slotIndex = local.getIndex(); } @@ -46,6 +46,10 @@ public LocalVariableReadNode(final LocalVariableReadNode node) { this(node.local); } + public boolean isSameLocal(final LocalVariableNode node) { + return local.equals(node.local); + } + @Specialization(guards = "isUninitialized(frame)") public static final SObject doNil(@SuppressWarnings("unused") final VirtualFrame frame) { return Nil.nilObject; diff --git a/src/trufflesom/src/trufflesom/interpreter/nodes/NonLocalVariableNode.java b/src/trufflesom/src/trufflesom/interpreter/nodes/NonLocalVariableNode.java index 8cd7d04aa..bdd2d9622 100644 --- a/src/trufflesom/src/trufflesom/interpreter/nodes/NonLocalVariableNode.java +++ b/src/trufflesom/src/trufflesom/interpreter/nodes/NonLocalVariableNode.java @@ -24,12 +24,16 @@ public abstract class NonLocalVariableNode extends ContextualNode protected final int slotIndex; protected final Local local; - private NonLocalVariableNode(final int contextLevel, final Local local) { + protected NonLocalVariableNode(final int contextLevel, final Local local) { super(contextLevel); this.local = local; this.slotIndex = local.getIndex(); } + public boolean isSameLocal(final NonLocalVariableNode node) { + return local.equals(node.local); + } + @Override public String getInvocationIdentifier() { return local.name; diff --git a/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalVariableReadSquareWriteNode.java b/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalVariableReadSquareWriteNode.java new file mode 100644 index 000000000..cc1600af1 --- /dev/null +++ b/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalVariableReadSquareWriteNode.java @@ -0,0 +1,84 @@ +package trufflesom.interpreter.supernodes; + +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.FrameDescriptor; +import com.oracle.truffle.api.frame.FrameSlotKind; +import com.oracle.truffle.api.frame.FrameSlotTypeException; +import com.oracle.truffle.api.frame.VirtualFrame; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement; +import trufflesom.compiler.Variable.Local; +import trufflesom.interpreter.nodes.LocalVariableNode; + + +public abstract class LocalVariableReadSquareWriteNode extends LocalVariableNode { + + protected final Local readLocal; + protected final int readIndex; + + public LocalVariableReadSquareWriteNode(final Local writeLocal, final Local readLocal) { + super(writeLocal); + this.readLocal = readLocal; + this.readIndex = readLocal.getIndex(); + } + + @Specialization(guards = {"isLongKind(frame)", "frame.isLong(readIndex)"}, + rewriteOn = {FrameSlotTypeException.class}) + public final long writeLong(final VirtualFrame frame) throws FrameSlotTypeException { + long current = frame.getLong(readIndex); + long result = Math.multiplyExact(current, current); + frame.setLong(slotIndex, result); + return result; + } + + @Specialization(guards = {"isDoubleKind(frame)", "frame.isDouble(readIndex)"}, + rewriteOn = {FrameSlotTypeException.class}) + public final double writeDouble(final VirtualFrame frame) throws FrameSlotTypeException { + double current = frame.getDouble(readIndex); + double result = current * current; + frame.setDouble(slotIndex, result); + return result; + } + + // uses frame to make sure guard is not converted to assertion + protected final boolean isLongKind(final VirtualFrame frame) { + FrameDescriptor descriptor = local.getFrameDescriptor(); + if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Long) { + return true; + } + if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Illegal) { + descriptor.setSlotKind(slotIndex, FrameSlotKind.Long); + return true; + } + return false; + } + + // uses frame to make sure guard is not converted to assertion + protected final boolean isDoubleKind(final VirtualFrame frame) { + FrameDescriptor descriptor = local.getFrameDescriptor(); + if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Double) { + return true; + } + if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Illegal) { + descriptor.setSlotKind(slotIndex, FrameSlotKind.Double); + return true; + } + return false; + } + + @Override + public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) { + ScopeElement seWrite = inliner.getAdaptedVar(local); + ScopeElement seRead = inliner.getAdaptedVar(readLocal); + + assert seWrite.contextLevel == seRead.contextLevel; + + if (seWrite.var != local || seWrite.contextLevel < 0) { + assert seRead.var != readLocal || seRead.contextLevel < 0; + replace(seWrite.var.getReadSquareWriteNode(seWrite.contextLevel, sourceCoord, + (Local) seRead.var, seRead.contextLevel)); + } else { + assert 0 == seWrite.contextLevel; + } + } +} diff --git a/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalVariableSquareNode.java b/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalVariableSquareNode.java new file mode 100644 index 000000000..a216ce1df --- /dev/null +++ b/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalVariableSquareNode.java @@ -0,0 +1,41 @@ +package trufflesom.interpreter.supernodes; + +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.FrameSlotTypeException; +import com.oracle.truffle.api.frame.VirtualFrame; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement; +import trufflesom.compiler.Variable.Local; +import trufflesom.interpreter.nodes.LocalVariableNode; + + +public abstract class LocalVariableSquareNode extends LocalVariableNode { + + public LocalVariableSquareNode(final Local variable) { + super(variable); + } + + @Specialization(guards = {"frame.isLong(slotIndex)"}, + rewriteOn = {FrameSlotTypeException.class}) + public final long doLong(final VirtualFrame frame) throws FrameSlotTypeException { + long value = frame.getLong(slotIndex); + return Math.multiplyExact(value, value); + } + + @Specialization(guards = {"frame.isDouble(slotIndex)"}, + rewriteOn = {FrameSlotTypeException.class}) + public final double doDouble(final VirtualFrame frame) throws FrameSlotTypeException { + double value = frame.getDouble(slotIndex); + return value * value; + } + + @Override + public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) { + ScopeElement se = inliner.getAdaptedVar(local); + if (se.var != local || se.contextLevel < 0) { + replace(se.var.getSquareNode(se.contextLevel, sourceCoord)); + } else { + assert 0 == se.contextLevel; + } + } +} diff --git a/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalVariableReadSquareWriteNode.java b/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalVariableReadSquareWriteNode.java new file mode 100644 index 000000000..5a77be480 --- /dev/null +++ b/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalVariableReadSquareWriteNode.java @@ -0,0 +1,148 @@ +package trufflesom.interpreter.supernodes; + +import com.oracle.truffle.api.dsl.Bind; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.Frame; +import com.oracle.truffle.api.frame.FrameDescriptor; +import com.oracle.truffle.api.frame.FrameSlotKind; +import com.oracle.truffle.api.frame.FrameSlotTypeException; +import com.oracle.truffle.api.frame.MaterializedFrame; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.ExplodeLoop; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement; +import trufflesom.compiler.Variable.Local; +import trufflesom.interpreter.nodes.NonLocalVariableNode; +import trufflesom.vmobjects.SBlock; + + +public abstract class NonLocalVariableReadSquareWriteNode extends NonLocalVariableNode { + + protected final Local readLocal; + protected final int readIndex; + protected final int readContextLevel; + + public NonLocalVariableReadSquareWriteNode(final int writeContextLevel, + final Local writeLocal, + final Local readLocal, final int readContextLevel) { + super(writeContextLevel, writeLocal); + this.readLocal = readLocal; + this.readIndex = readLocal.getIndex(); + this.readContextLevel = readContextLevel; + } + + @ExplodeLoop + protected final Frame determineReadContext(final VirtualFrame frame) { + if (readContextLevel == 0) { + return frame; + } + + SBlock self = (SBlock) frame.getArguments()[0]; + int i = readContextLevel - 1; + + while (i > 0) { + self = (SBlock) self.getOuterSelf(); + i--; + } + + // Graal needs help here to see that this is always a MaterializedFrame + // so, we record explicitly a class profile + return frameType.profile(self.getContext()); + } + + @Specialization( + guards = {"isLongKind(ctx)", "ctx.isLong(readIndex)", + "contextLevel == readContextLevel"}, + rewriteOn = {FrameSlotTypeException.class}) + public final long writeLongSameContext(final VirtualFrame frame, + @Bind("determineContext(frame)") final MaterializedFrame ctx) + throws FrameSlotTypeException { + long current = ctx.getLong(readIndex); + long result = Math.multiplyExact(current, current); + + ctx.setLong(slotIndex, result); + + return result; + } + + @Specialization(guards = {"isLongKind(writeCtx)", "readCtx.isLong(readIndex)"}, + rewriteOn = {FrameSlotTypeException.class}) + public final long writeLong(final VirtualFrame frame, + @Bind("determineContext(frame)") final MaterializedFrame writeCtx, + @Bind("determineReadContext(frame)") final Frame readCtx) + throws FrameSlotTypeException { + long current = readCtx.getLong(readIndex); + long result = Math.multiplyExact(current, current); + + writeCtx.setLong(slotIndex, result); + + return result; + } + + @Specialization( + guards = {"isDoubleKind(ctx)", "ctx.isDouble(readIndex)", + "contextLevel == readContextLevel"}, + rewriteOn = {FrameSlotTypeException.class}) + public final double writeDoubleSameContext(final VirtualFrame frame, + @Bind("determineContext(frame)") final MaterializedFrame ctx) + throws FrameSlotTypeException { + double current = ctx.getDouble(readIndex); + double result = current * current; + + ctx.setDouble(slotIndex, result); + return result; + } + + @Specialization(guards = {"isDoubleKind(writeCtx)", "readCtx.isDouble(readIndex)"}, + rewriteOn = {FrameSlotTypeException.class}) + public final double writeDouble(final VirtualFrame frame, + @Bind("determineContext(frame)") final MaterializedFrame writeCtx, + @Bind("determineReadContext(frame)") final Frame readCtx) + throws FrameSlotTypeException { + double current = readCtx.getDouble(readIndex); + double result = current * current; + + writeCtx.setDouble(slotIndex, result); + return result; + } + + protected final boolean isLongKind(final VirtualFrame frame) { + FrameDescriptor descriptor = local.getFrameDescriptor(); + if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Long) { + return true; + } + if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Illegal) { + descriptor.setSlotKind(slotIndex, FrameSlotKind.Long); + return true; + } + return false; + } + + protected final boolean isDoubleKind(final VirtualFrame frame) { + FrameDescriptor descriptor = local.getFrameDescriptor(); + if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Double) { + return true; + } + if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Illegal) { + descriptor.setSlotKind(slotIndex, FrameSlotKind.Double); + return true; + } + return false; + } + + @Override + public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) { + ScopeElement seWrite = inliner.getAdaptedVar(local); + ScopeElement seRead = inliner.getAdaptedVar(readLocal); + + assert seWrite.contextLevel == seRead.contextLevel; + + if (seWrite.var != local || seWrite.contextLevel < contextLevel) { + assert seRead.var != readLocal || seRead.contextLevel < contextLevel; + replace(seWrite.var.getReadSquareWriteNode(seWrite.contextLevel, sourceCoord, + (Local) seRead.var, seRead.contextLevel)); + } else { + assert contextLevel == seWrite.contextLevel; + } + } +} diff --git a/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalVariableSquareNode.java b/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalVariableSquareNode.java new file mode 100644 index 000000000..c9aacc897 --- /dev/null +++ b/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalVariableSquareNode.java @@ -0,0 +1,47 @@ +package trufflesom.interpreter.supernodes; + +import com.oracle.truffle.api.dsl.Bind; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.FrameSlotTypeException; +import com.oracle.truffle.api.frame.MaterializedFrame; +import com.oracle.truffle.api.frame.VirtualFrame; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement; +import trufflesom.compiler.Variable.Local; +import trufflesom.interpreter.nodes.NonLocalVariableNode; + + +public abstract class NonLocalVariableSquareNode extends NonLocalVariableNode { + + public NonLocalVariableSquareNode(final int contextLevel, final Local local) { + super(contextLevel, local); + } + + @Specialization(guards = {"ctx.isLong(slotIndex)"}, + rewriteOn = {FrameSlotTypeException.class}) + public final long doLong(final VirtualFrame frame, + @Bind("determineContext(frame)") final MaterializedFrame ctx) + throws FrameSlotTypeException { + long current = ctx.getLong(slotIndex); + return Math.multiplyExact(current, current); + } + + @Specialization(guards = {"ctx.isDouble(slotIndex)"}, + rewriteOn = {FrameSlotTypeException.class}) + public final double doDouble(final VirtualFrame frame, + @Bind("determineContext(frame)") final MaterializedFrame ctx) + throws FrameSlotTypeException { + double current = ctx.getDouble(slotIndex); + return current * current; + } + + @Override + public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) { + ScopeElement se = inliner.getAdaptedVar(local); + if (se.var != local || se.contextLevel < contextLevel) { + replace(se.var.getSquareNode(se.contextLevel, sourceCoord)); + } else { + assert contextLevel == se.contextLevel; + } + } +} diff --git a/tests/trufflesom/supernodes/SquareTests.java b/tests/trufflesom/supernodes/SquareTests.java new file mode 100644 index 000000000..fa5beaa57 --- /dev/null +++ b/tests/trufflesom/supernodes/SquareTests.java @@ -0,0 +1,83 @@ +package trufflesom.supernodes; + +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.hamcrest.MatcherAssert.assertThat; + +import org.junit.Test; +import trufflesom.interpreter.nodes.ExpressionNode; +import trufflesom.interpreter.nodes.LocalVariableNode.LocalVariableWriteNode; +import trufflesom.interpreter.nodes.NonLocalVariableNode.NonLocalVariableWriteNode; +import trufflesom.interpreter.nodes.SequenceNode; +import trufflesom.interpreter.nodes.literals.BlockNode; +import trufflesom.interpreter.supernodes.LocalVariableReadSquareWriteNode; +import trufflesom.interpreter.supernodes.LocalVariableSquareNode; +import trufflesom.interpreter.supernodes.NonLocalVariableReadSquareWriteNode; +import trufflesom.interpreter.supernodes.NonLocalVariableSquareNode; +import trufflesom.primitives.arithmetic.MultiplicationPrim; +import trufflesom.tests.AstTestSetup; + + +public class SquareTests extends AstTestSetup { + @SuppressWarnings("unchecked") + private T assertThatMainNodeIs(final String test, final Class expectedNode) { + SequenceNode seq = (SequenceNode) parseMethod( + "test = ( | l1 l2 l3 l4 | \n" + test + " )"); + + ExpressionNode testExpr = read(seq, "expressions", 0); + assertThat(testExpr, instanceOf(expectedNode)); + return (T) testExpr; + } + + @Test + public void testJustSquareLocals() { + LocalVariableSquareNode s = + assertThatMainNodeIs("l2 * l2.", LocalVariableSquareNode.class); + assertEquals(s.getLocal().name, "l2"); + + s = assertThatMainNodeIs("l1 * l1.", LocalVariableSquareNode.class); + assertEquals(s.getLocal().name, "l1"); + + assertThatMainNodeIs("l1 * l3.", MultiplicationPrim.class); + } + + @SuppressWarnings("unchecked") + private T inBlock(final String test, final Class expectedNode) { + addField("field"); + SequenceNode seq = (SequenceNode) parseMethod( + "test: arg = ( | l1 l2 l3 l4 | \n" + test + " )"); + + BlockNode block = (BlockNode) read(seq, "expressions", 0); + ExpressionNode testExpr = + read(block.getMethod().getInvokable(), "body", ExpressionNode.class); + assertThat(testExpr, instanceOf(expectedNode)); + return (T) testExpr; + } + + @Test + public void testJustSquareNonLocals() { + NonLocalVariableSquareNode s = inBlock("[ l2 * l2 ]", NonLocalVariableSquareNode.class); + assertEquals(s.getLocal().name, "l2"); + + s = inBlock("[ l1 * l1 ]", NonLocalVariableSquareNode.class); + assertEquals(s.getLocal().name, "l1"); + + inBlock("[ l1 * l3 ]", MultiplicationPrim.class); + } + + @Test + public void testSquareAndAssignLocal() { + assertThatMainNodeIs("l1 := l2 * l2.", LocalVariableReadSquareWriteNode.class); + assertThatMainNodeIs("l2 := l2 * l2.", LocalVariableReadSquareWriteNode.class); + + assertThatMainNodeIs("l3 := l1 * l2.", LocalVariableWriteNode.class); + } + + @Test + public void testSquareAndAssignNonLocal() { + inBlock("[ l1 := l2 * l2 ]", NonLocalVariableReadSquareWriteNode.class); + inBlock("[ l2 := l2 * l2 ]", NonLocalVariableReadSquareWriteNode.class); + + inBlock("[ l3 := l1 * l2 ]", NonLocalVariableWriteNode.class); + } +}