Skip to content

Commit

Permalink
Add Square supernodes (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
smarr authored Oct 4, 2024
2 parents f9513bd + 87ef25c commit 4005487
Show file tree
Hide file tree
Showing 10 changed files with 499 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
import static trufflesom.vm.SymbolTable.strBlockSelf;
import static trufflesom.vm.SymbolTable.strFrameOnStack;
import static trufflesom.vm.SymbolTable.strSelf;
import static trufflesom.vm.SymbolTable.symBlockSelf;
import static trufflesom.vm.SymbolTable.symFrameOnStack;
import static trufflesom.vm.SymbolTable.symSelf;
import static trufflesom.vm.SymbolTable.symbolFor;

import java.util.ArrayList;
Expand Down Expand Up @@ -59,6 +56,8 @@
import trufflesom.interpreter.nodes.ReturnNonLocalNode.CatchNonLocalReturnNode;
import trufflesom.interpreter.nodes.literals.BlockNode;
import trufflesom.interpreter.supernodes.IntIncrementNode;
import trufflesom.interpreter.supernodes.LocalVariableSquareNode;
import trufflesom.interpreter.supernodes.NonLocalVariableSquareNode;
import trufflesom.primitives.Primitives;
import trufflesom.vmobjects.SClass;
import trufflesom.vmobjects.SInvokable;
Expand Down Expand Up @@ -404,7 +403,18 @@ public ExpressionNode getLocalReadNode(final Variable variable, final long coord

public ExpressionNode getLocalWriteNode(final Variable variable,
final ExpressionNode valExpr, final long coord) {
return variable.getWriteNode(getContextLevel(variable), valExpr, coord);
int ctxLevel = getContextLevel(variable);

if (valExpr instanceof LocalVariableSquareNode l) {
return variable.getReadSquareWriteNode(ctxLevel, coord, l.getLocal(), 0);
}

if (valExpr instanceof NonLocalVariableSquareNode nl) {
return variable.getReadSquareWriteNode(ctxLevel, coord, nl.getLocal(),
nl.getContextLevel());
}

return variable.getWriteNode(ctxLevel, valExpr, coord);
}

protected Local getLocal(final String varName) {
Expand Down
18 changes: 18 additions & 0 deletions src/trufflesom/src/trufflesom/compiler/ParserAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
import trufflesom.interpreter.nodes.FieldNode;
import trufflesom.interpreter.nodes.FieldNode.FieldReadNode;
import trufflesom.interpreter.nodes.GlobalNode;
import trufflesom.interpreter.nodes.LocalVariableNode.LocalVariableReadNode;
import trufflesom.interpreter.nodes.MessageSendNode;
import trufflesom.interpreter.nodes.NonLocalVariableNode.NonLocalVariableReadNode;
import trufflesom.interpreter.nodes.SequenceNode;
import trufflesom.interpreter.nodes.literals.BlockNode;
import trufflesom.interpreter.nodes.literals.BlockNode.BlockNodeWithContext;
Expand All @@ -45,7 +47,9 @@
import trufflesom.interpreter.nodes.literals.LiteralNode;
import trufflesom.interpreter.supernodes.IntIncrementNodeGen;
import trufflesom.interpreter.supernodes.LocalFieldStringEqualsNode;
import trufflesom.interpreter.supernodes.LocalVariableSquareNodeGen;
import trufflesom.interpreter.supernodes.NonLocalFieldStringEqualsNode;
import trufflesom.interpreter.supernodes.NonLocalVariableSquareNodeGen;
import trufflesom.interpreter.supernodes.StringEqualsNodeGen;
import trufflesom.primitives.Primitives;
import trufflesom.vm.Globals;
Expand Down Expand Up @@ -307,6 +311,20 @@ protected ExpressionNode binaryMessage(final MethodGenerationContext mgenc,
return StringEqualsNodeGen.create(s, operand).initialize(coordWithL);
}
}
} else if (binSelector.equals("*")) {
if (receiver instanceof LocalVariableReadNode rcvr
&& operand instanceof LocalVariableReadNode op) {
if (rcvr.isSameLocal(op)) {
return LocalVariableSquareNodeGen.create(rcvr.getLocal()).initialize(coordWithL);
}
} else if (receiver instanceof NonLocalVariableReadNode rcvr
&& operand instanceof NonLocalVariableReadNode op) {
if (rcvr.isSameLocal(op)) {
assert rcvr.getContextLevel() == op.getContextLevel();
return NonLocalVariableSquareNodeGen.create(
rcvr.getContextLevel(), rcvr.getLocal()).initialize(coordWithL);
}
}
}

ExpressionNode inlined =
Expand Down
52 changes: 52 additions & 0 deletions src/trufflesom/src/trufflesom/compiler/Variable.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
import trufflesom.interpreter.nodes.LocalVariableNodeFactory.LocalVariableWriteNodeGen;
import trufflesom.interpreter.nodes.NonLocalVariableNodeFactory.NonLocalVariableReadNodeGen;
import trufflesom.interpreter.nodes.NonLocalVariableNodeFactory.NonLocalVariableWriteNodeGen;
import trufflesom.interpreter.supernodes.LocalVariableReadSquareWriteNodeGen;
import trufflesom.interpreter.supernodes.LocalVariableSquareNodeGen;
import trufflesom.interpreter.supernodes.NonLocalVariableReadSquareWriteNodeGen;
import trufflesom.interpreter.supernodes.NonLocalVariableSquareNodeGen;
import trufflesom.vm.NotYetImplementedException;
import trufflesom.vmobjects.SSymbol;

Expand Down Expand Up @@ -59,6 +63,11 @@ public String toString() {
public abstract ExpressionNode getWriteNode(
int contextLevel, ExpressionNode valueExpr, long coord);

public abstract ExpressionNode getSquareNode(int contextLevel, long coord);

public abstract ExpressionNode getReadSquareWriteNode(int writeContextLevel, long coord,
Local readLocal, int readContextLevel);

protected abstract void emitPop(BytecodeMethodGenContext mgenc);

protected abstract void emitPush(BytecodeMethodGenContext mgenc);
Expand Down Expand Up @@ -136,6 +145,17 @@ public ExpressionNode getWriteNode(final int contextLevel,
}
}

@Override
public ExpressionNode getSquareNode(final int contextLevel, final long coord) {
throw new NotYetImplementedException();
}

@Override
public ExpressionNode getReadSquareWriteNode(final int writeContextLevel, final long coord,
final Local readLocal, final int readContextLevel) {
throw new NotYetImplementedException();
}

@Override
public void emitPop(final BytecodeMethodGenContext mgenc) {
emitPOPARGUMENT(mgenc, (byte) index, (byte) mgenc.getContextLevel(this));
Expand Down Expand Up @@ -176,6 +196,25 @@ public ExpressionNode getReadNode(final int contextLevel, final long coordinate)
return LocalVariableReadNodeGen.create(this).initialize(coordinate);
}

@Override
public ExpressionNode getSquareNode(final int contextLevel, final long coord) {
if (contextLevel > 0) {
return NonLocalVariableSquareNodeGen.create(contextLevel, this).initialize(coord);
}
return LocalVariableSquareNodeGen.create(this).initialize(coord);
}

@Override
public ExpressionNode getReadSquareWriteNode(final int writeContextLevel, final long coord,
final Local readLocal, final int readContextLevel) {
if (writeContextLevel > 0 || readContextLevel > 0) {
return NonLocalVariableReadSquareWriteNodeGen.create(
writeContextLevel, this, readLocal, readContextLevel).initialize(coord);
}
assert readContextLevel == 0;
return LocalVariableReadSquareWriteNodeGen.create(this, readLocal).initialize(coord);
}

public final int getIndex() {
return slotIndex;
}
Expand Down Expand Up @@ -234,6 +273,19 @@ public ExpressionNode getReadNode(final int contextLevel, final long coordinate)
+ "They are used directly by other nodes.");
}

@Override
public ExpressionNode getSquareNode(final int contextLevel, final long coord) {
throw new UnsupportedOperationException(
"There shouldn't be any language-level square nodes for internal slots. ");
}

@Override
public ExpressionNode getReadSquareWriteNode(final int readContextLevel, final long coord,
final Local readLocal, final int writeContextLevel) {
throw new UnsupportedOperationException(
"There shouldn't be any language-level square nodes for internal slots. ");
}

@Override
public Internal split() {
return new Internal(name, coord, slotIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public abstract class LocalVariableNode extends NoPreEvalExprNode
// TODO: We currently assume that there is a 1:1 mapping between lexical contexts
// and frame descriptors, which is apparently not strictly true anymore in Truffle 1.0.0.
// Generally, we also need to revise everything in this area and address issue SOMns#240.
private LocalVariableNode(final Local local) {
protected LocalVariableNode(final Local local) {
this.local = local;
this.slotIndex = local.getIndex();
}
Expand All @@ -46,6 +46,10 @@ public LocalVariableReadNode(final LocalVariableReadNode node) {
this(node.local);
}

public boolean isSameLocal(final LocalVariableNode node) {
return local.equals(node.local);
}

@Specialization(guards = "isUninitialized(frame)")
public static final SObject doNil(@SuppressWarnings("unused") final VirtualFrame frame) {
return Nil.nilObject;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@ public abstract class NonLocalVariableNode extends ContextualNode
protected final int slotIndex;
protected final Local local;

private NonLocalVariableNode(final int contextLevel, final Local local) {
protected NonLocalVariableNode(final int contextLevel, final Local local) {
super(contextLevel);
this.local = local;
this.slotIndex = local.getIndex();
}

public boolean isSameLocal(final NonLocalVariableNode node) {
return local.equals(node.local);
}

@Override
public String getInvocationIdentifier() {
return local.name;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package trufflesom.interpreter.supernodes;

import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.FrameDescriptor;
import com.oracle.truffle.api.frame.FrameSlotKind;
import com.oracle.truffle.api.frame.FrameSlotTypeException;
import com.oracle.truffle.api.frame.VirtualFrame;
import trufflesom.bdt.inlining.ScopeAdaptationVisitor;
import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement;
import trufflesom.compiler.Variable.Local;
import trufflesom.interpreter.nodes.LocalVariableNode;


public abstract class LocalVariableReadSquareWriteNode extends LocalVariableNode {

protected final Local readLocal;
protected final int readIndex;

public LocalVariableReadSquareWriteNode(final Local writeLocal, final Local readLocal) {
super(writeLocal);
this.readLocal = readLocal;
this.readIndex = readLocal.getIndex();
}

@Specialization(guards = {"isLongKind(frame)", "frame.isLong(readIndex)"},
rewriteOn = {FrameSlotTypeException.class})
public final long writeLong(final VirtualFrame frame) throws FrameSlotTypeException {
long current = frame.getLong(readIndex);
long result = Math.multiplyExact(current, current);
frame.setLong(slotIndex, result);
return result;
}

@Specialization(guards = {"isDoubleKind(frame)", "frame.isDouble(readIndex)"},
rewriteOn = {FrameSlotTypeException.class})
public final double writeDouble(final VirtualFrame frame) throws FrameSlotTypeException {
double current = frame.getDouble(readIndex);
double result = current * current;
frame.setDouble(slotIndex, result);
return result;
}

// uses frame to make sure guard is not converted to assertion
protected final boolean isLongKind(final VirtualFrame frame) {
FrameDescriptor descriptor = local.getFrameDescriptor();
if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Long) {
return true;
}
if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Illegal) {
descriptor.setSlotKind(slotIndex, FrameSlotKind.Long);
return true;
}
return false;
}

// uses frame to make sure guard is not converted to assertion
protected final boolean isDoubleKind(final VirtualFrame frame) {
FrameDescriptor descriptor = local.getFrameDescriptor();
if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Double) {
return true;
}
if (descriptor.getSlotKind(slotIndex) == FrameSlotKind.Illegal) {
descriptor.setSlotKind(slotIndex, FrameSlotKind.Double);
return true;
}
return false;
}

@Override
public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) {
ScopeElement seWrite = inliner.getAdaptedVar(local);
ScopeElement seRead = inliner.getAdaptedVar(readLocal);

assert seWrite.contextLevel == seRead.contextLevel;

if (seWrite.var != local || seWrite.contextLevel < 0) {
assert seRead.var != readLocal || seRead.contextLevel < 0;
replace(seWrite.var.getReadSquareWriteNode(seWrite.contextLevel, sourceCoord,
(Local) seRead.var, seRead.contextLevel));
} else {
assert 0 == seWrite.contextLevel;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package trufflesom.interpreter.supernodes;

import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.FrameSlotTypeException;
import com.oracle.truffle.api.frame.VirtualFrame;
import trufflesom.bdt.inlining.ScopeAdaptationVisitor;
import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement;
import trufflesom.compiler.Variable.Local;
import trufflesom.interpreter.nodes.LocalVariableNode;


public abstract class LocalVariableSquareNode extends LocalVariableNode {

public LocalVariableSquareNode(final Local variable) {
super(variable);
}

@Specialization(guards = {"frame.isLong(slotIndex)"},
rewriteOn = {FrameSlotTypeException.class})
public final long doLong(final VirtualFrame frame) throws FrameSlotTypeException {
long value = frame.getLong(slotIndex);
return Math.multiplyExact(value, value);
}

@Specialization(guards = {"frame.isDouble(slotIndex)"},
rewriteOn = {FrameSlotTypeException.class})
public final double doDouble(final VirtualFrame frame) throws FrameSlotTypeException {
double value = frame.getDouble(slotIndex);
return value * value;
}

@Override
public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) {
ScopeElement se = inliner.getAdaptedVar(local);
if (se.var != local || se.contextLevel < 0) {
replace(se.var.getSquareNode(se.contextLevel, sourceCoord));
} else {
assert 0 == se.contextLevel;
}
}
}
Loading

0 comments on commit 4005487

Please sign in to comment.