Skip to content

Commit

Permalink
adding tests
Browse files Browse the repository at this point in the history
Signed-off-by: Sarthak Aggarwal <[email protected]>
  • Loading branch information
sarthakaggarwal97 committed Aug 25, 2024
1 parent b154979 commit bf673e7
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public byte getStarTreeNodeType() throws IOException {
}

@Override
public StarTreeNode getChildForDimensionValue(long dimensionValue, boolean isStar) throws IOException {
public StarTreeNode getChildForDimensionValue(Long dimensionValue, boolean isStar) throws IOException {
// there will be no children for leaf nodes
if (isLeaf()) {
return null;
Expand All @@ -197,9 +197,11 @@ public StarTreeNode getChildForDimensionValue(long dimensionValue, boolean isSta
if (isStar) {
return handleStarNode();
}

StarTreeNode resultStarTreeNode = binarySearchChild(dimensionValue);
assert null != resultStarTreeNode;
StarTreeNode resultStarTreeNode = null;
if (null != dimensionValue) {
resultStarTreeNode = binarySearchChild(dimensionValue);
assert null != resultStarTreeNode;
}
return resultStarTreeNode;
}

Expand Down Expand Up @@ -232,11 +234,11 @@ private FixedLengthStarTreeNode binarySearchChild(long dimensionValue) throws IO
while (low <= high) {
int mid = low + (high - low) / 2;
FixedLengthStarTreeNode midNode = new FixedLengthStarTreeNode(in, mid);
long midNodeDimensionValue = midNode.getDimensionValue();
long midDimensionValue = midNode.getDimensionValue();

if (midNodeDimensionValue == dimensionValue) {
if (midDimensionValue == dimensionValue) {
return midNode;
} else if (midNodeDimensionValue < dimensionValue) {
} else if (midDimensionValue < dimensionValue) {
low = mid + 1;
} else {
high = mid - 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public interface StarTreeNode {
* @return the child node for the given dimension value or null if child is not present
* @throws IOException if an I/O error occurs while retrieving the child node
*/
StarTreeNode getChildForDimensionValue(long dimensionValue, boolean isStar) throws IOException;
StarTreeNode getChildForDimensionValue(Long dimensionValue, boolean isStar) throws IOException;

/**
* Returns an iterator over the children of the current star-tree node.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@ public class StarTreeFileFormatsTests extends OpenSearchTestCase {
private IndexOutput dataOut;
private IndexInput dataIn;
private Directory directory;
private Integer maxLevels;
private static Integer dimensionValue;

@Before
public void setup() throws IOException {
directory = newFSDirectory(createTempDir());
maxLevels = randomIntBetween(2, 5);
dimensionValue = 0;
}

public void test_StarTreeNode() throws IOException {
Expand All @@ -48,10 +52,10 @@ public void test_StarTreeNode() throws IOException {
Map<Long, InMemoryTreeNode> levelOrderStarTreeNodeMap = new LinkedHashMap<>();
InMemoryTreeNode root = generateSampleTree(levelOrderStarTreeNodeMap);
StarTreeWriter starTreeWriter = new StarTreeWriter();
long starTreeDataLength = starTreeWriter.writeStarTree(dataOut, root, 7, "star-tree");
long starTreeDataLength = starTreeWriter.writeStarTree(dataOut, root, levelOrderStarTreeNodeMap.size(), "star-tree");

// asserting on the actual length of the star tree data file
assertEquals(starTreeDataLength, 247);
assertEquals(starTreeDataLength, (levelOrderStarTreeNodeMap.size() * 33L) + 16);
dataOut.close();

dataIn = directory.openInput("star-tree-data", IOContext.READONCE);
Expand Down Expand Up @@ -107,98 +111,49 @@ private void assertStarTreeNode(StarTreeNode starTreeNode, InMemoryTreeNode tree

}

private InMemoryTreeNode generateSampleTree(Map<Long, InMemoryTreeNode> levelOrderStarTreeNode) {
public InMemoryTreeNode generateSampleTree(Map<Long, InMemoryTreeNode> levelOrderStarTreeNode) {
// Create the root node
InMemoryTreeNode root = new InMemoryTreeNode();
root.dimensionId = 0;
root.startDocId = 0;
root.endDocId = 100;
root.startDocId = randomInt();
root.endDocId = randomInt();
root.childDimensionId = 1;
root.aggregatedDocId = randomInt();
root.nodeType = (byte) 0;
root.children = new HashMap<>();

levelOrderStarTreeNode.put(root.dimensionValue, root);

// Create child nodes for dimension 1
InMemoryTreeNode dim1Node1 = new InMemoryTreeNode();
dim1Node1.dimensionId = 1;
dim1Node1.dimensionValue = 1;
dim1Node1.startDocId = 0;
dim1Node1.endDocId = 50;
dim1Node1.childDimensionId = 2;
dim1Node1.aggregatedDocId = randomInt();
root.nodeType = (byte) 0;
dim1Node1.children = new HashMap<>();

InMemoryTreeNode dim1Node2 = new InMemoryTreeNode();
dim1Node2.dimensionId = 1;
dim1Node2.dimensionValue = 2;
dim1Node2.startDocId = 50;
dim1Node2.endDocId = 100;
dim1Node2.childDimensionId = 2;
dim1Node2.aggregatedDocId = randomInt();
root.nodeType = (byte) 0;
dim1Node2.children = new HashMap<>();

root.children.put(1L, dim1Node1);
root.children.put(2L, dim1Node2);

levelOrderStarTreeNode.put(dim1Node1.dimensionValue, dim1Node1);
levelOrderStarTreeNode.put(dim1Node2.dimensionValue, dim1Node2);

// Create child nodes for dimension 2
InMemoryTreeNode dim2Node1 = new InMemoryTreeNode();
dim2Node1.dimensionId = 2;
dim2Node1.dimensionValue = 3;
dim2Node1.startDocId = 0;
dim2Node1.endDocId = 25;
dim2Node1.childDimensionId = -1;
dim2Node1.aggregatedDocId = randomInt();
root.nodeType = (byte) 0;
dim2Node1.children = null;

InMemoryTreeNode dim2Node2 = new InMemoryTreeNode();
dim2Node2.dimensionId = 2;
dim2Node2.dimensionValue = 4;
dim2Node2.startDocId = 25;
dim2Node2.endDocId = 50;
dim2Node2.childDimensionId = -1;
dim2Node2.aggregatedDocId = randomInt();
root.nodeType = (byte) 0;
dim2Node2.children = null;

InMemoryTreeNode dim2Node3 = new InMemoryTreeNode();
dim2Node3.dimensionId = 2;
dim2Node3.dimensionValue = 5;
dim2Node3.startDocId = 50;
dim2Node3.endDocId = 75;
dim2Node3.childDimensionId = -1;
dim2Node3.aggregatedDocId = randomInt();
root.nodeType = (byte) 0;
dim2Node3.children = null;

InMemoryTreeNode dim2Node4 = new InMemoryTreeNode();
dim2Node4.dimensionId = 2;
dim2Node4.dimensionValue = 6;
dim2Node4.startDocId = 75;
dim2Node4.endDocId = 100;
dim2Node4.childDimensionId = -1;
dim2Node4.aggregatedDocId = randomInt();
root.nodeType = (byte) 0;
dim2Node4.children = null;
// Generate the tree recursively
generateTreeRecursively(root, 1, levelOrderStarTreeNode);

return root;
}

dim1Node1.children.put(3L, dim2Node1);
dim1Node1.children.put(4L, dim2Node2);
dim1Node2.children.put(5L, dim2Node3);
dim1Node2.children.put(6L, dim2Node4);
private void generateTreeRecursively(InMemoryTreeNode parent, int currentLevel, Map<Long, InMemoryTreeNode> levelOrderStarTreeNode) {
if (currentLevel >= this.maxLevels) {
return; // Maximum level reached, stop generating children
}

levelOrderStarTreeNode.put(dim2Node1.dimensionValue, dim2Node1);
levelOrderStarTreeNode.put(dim2Node2.dimensionValue, dim2Node2);
levelOrderStarTreeNode.put(dim2Node3.dimensionValue, dim2Node3);
levelOrderStarTreeNode.put(dim2Node4.dimensionValue, dim2Node4);
int numChildren = randomIntBetween(1, 10);

return root;
for (int i = 0; i < numChildren; i++) {
InMemoryTreeNode child = new InMemoryTreeNode();
dimensionValue++;
child.dimensionId = currentLevel;
child.dimensionValue = dimensionValue; // Assign a unique dimension value for each child
child.startDocId = randomInt();
child.endDocId = randomInt();
child.childDimensionId = (currentLevel == this.maxLevels - 1) ? -1 : (currentLevel + 1);
child.aggregatedDocId = randomInt();
child.nodeType = (byte) 0;
child.children = new HashMap<>();

parent.children.put(child.dimensionValue, child);
levelOrderStarTreeNode.put(child.dimensionValue, child);

generateTreeRecursively(child, currentLevel + 1, levelOrderStarTreeNode);
}
}

public void tearDown() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void testGetFieldInfo() {

}

private void assertFieldInfos(FieldInfo actualFieldInfo, String fieldName, Integer fieldNumber){
private void assertFieldInfos(FieldInfo actualFieldInfo, String fieldName, Integer fieldNumber) {
assertEquals(fieldName, actualFieldInfo.name);
assertEquals(fieldNumber, actualFieldInfo.number, 0);
assertFalse(actualFieldInfo.hasVectorValues());
Expand All @@ -75,5 +75,4 @@ private void assertFieldInfos(FieldInfo actualFieldInfo, String fieldName, Integ
assertFalse(actualFieldInfo.isSoftDeletesField());
}


}

0 comments on commit bf673e7

Please sign in to comment.