From 6ea5c996ba1fe83e9df23b9ea2c17218a53c351c Mon Sep 17 00:00:00 2001 From: Jonathan Date: Sun, 29 Jun 2025 12:41:11 +0200 Subject: [PATCH 1/4] Implement getNames and setNames builtin functions --- .../org/apache/sysds/common/Builtins.java | 4 +- .../parser/BuiltinFunctionExpression.java | 35 +++++++++++ .../sysds/runtime/frame/data/FrameBlock.java | 30 ++++++++++ .../BuiltinGetSetNamesTest.java | 59 +++++++++++++++++++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 423679d038c..d2360bb30cf 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -403,7 +403,9 @@ public enum Builtins { UNIQUE("unique", false, true), UPPER_TRI("upper.tri", false, true), XDUMMY1("xdummy1", true), //error handling test - XDUMMY2("xdummy2", true); //error handling test + XDUMMY2("xdummy2", true), //error handling test + GETNAMES("getNames", false, true), + SETNAMES("setNames", false, true); Builtins(String name, boolean script) { this(name, null, script, false, ReturnType.SINGLE_RETURN); diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index ae582b052b2..1310756c18c 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -751,12 +751,47 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) else raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); break; + + case GETNAMES: + checkNumParameters(1); + Expression getNamesExpr = getFirstExpr(); + validFrameInput(getNamesExpr, _opcode.toString()); + + DataIdentifier getNamesOut = (DataIdentifier) getOutputs()[0]; + getNamesOut.setDataType(DataType.FRAME); + getNamesOut.setValueType(ValueType.STRING); + getNamesOut.setDimensions(1, getNamesExpr.getOutput().getDim2()); + getNamesOut.setBlocksize(getNamesExpr.getOutput().getBlocksize()); + break; + + case SETNAMES: + checkNumParameters(2); + Expression target = getFirstExpr(); + Expression nameRow = getSecondExpr(); + validFrameInput(target, _opcode + " (first parameter)"); + validFrameInput(nameRow, _opcode + " (second parameter)"); + if (nameRow.getOutput().getDim1() != 1) { + raiseValidateError("Second parameter of set names must be a single row frame", false); + } + DataIdentifier setNamesOut = (DataIdentifier) getOutputs()[0]; + setNamesOut.setDataType(DataType.FRAME); + setNamesOut.setValueType(target.getOutput().getValueType()); + setNamesOut.setDimensions(target.getOutput().getDim1(), target.getOutput().getDim2()); + setNamesOut.setBlocksize(target.getOutput().getBlocksize()); + break; default: //always unconditional raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); } } + private void validFrameInput(Expression expr, String context) { + if (expr == null || expr.getOutput() == null || expr.getOutput().getDataType() != DataType.FRAME) { + String dtype = (expr != null && expr.getOutput() != null) ? expr.getOutput().getDataType().toString() : "null"; + raiseValidateError("Expecting frame parameter for " + context, false); + } + } + private static boolean isPowerOfTwo(long n) { return (n > 0) && ((n & (n - 1)) == 0); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 63cadb43cf4..34d01c84587 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -360,6 +360,36 @@ public void setColumnName(int index, String name) { _colnames[index] = name; } + /** + * Returns the column names of FrameBlock + * If column names are not set, return default names (e.g. "C1", "C2"...) + * @return an array of column names + * The actual function is the same to getColumnNamesAsFrame() + */ + public FrameBlock getNames() { + return getColumnNamesAsFrame(); + } + + public void setNames(FrameBlock names) { + if (names == null){ + throw new DMLRuntimeException("Input FrameBlock can not be null."); + } + if (names.getNumRows() != 1) { + throw new DMLRuntimeException("Input FrameBlock must be single line."); + } + if (names.getNumColumns() != this.getNumColumns()) { + throw new DMLRuntimeException("Number of columns does not match."); + } + this._colnames = new String[names.getNumColumns()]; + for (int j = 0; j < names.getNumColumns(); j++) { + String name = names.getString(0, j); + if (name == null) { + throw new DMLRuntimeException("Column names can not contain null values"); + } + _colnames[j] = name; + } + } + public ColumnMetadata[] getColumnMetadata() { return _colmeta; } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java new file mode 100644 index 00000000000..f779667e53a --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java @@ -0,0 +1,59 @@ +package org.apache.sysds.test.functions.builtin.BuiltinGetSetNamesTest; + +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class BuiltinGetSetNamesTest { + @Test + public void testGetDefaultNames() { + FrameBlock fb = new FrameBlock(3, ValueType.STRING); + FrameBlock names = fb.getNames(); + assertEquals("C1", names.getString(0, 0)); + assertEquals("C2", names.getString(0, 1)); + assertEquals("C3", names.getString(0, 2)); + } + + @Test + public void testSetAndGetCustomNames() { + FrameBlock fb = new FrameBlock(2, ValueType.STRING); + FrameBlock nameRow = new FrameBlock(2, ValueType.STRING); + nameRow.appendRow(new String[] {"name", "age"}); + + fb.setNames(nameRow); + + FrameBlock result = fb.getNames(); + assertEquals("name", result.getString(0, 0)); + assertEquals("age", result.getString(0, 1)); + } + + @Test(expected = DMLRuntimeException.class) + public void testSetNamesNullFrame() { + FrameBlock fb = new FrameBlock(2, ValueType.STRING); + fb.setNames(null); + } + + @Test(expected = DMLRuntimeException.class) + public void testSetNamesWrongRowCount() { + FrameBlock fb = new FrameBlock(2, ValueType.STRING); + FrameBlock nameRows = new FrameBlock(2, ValueType.STRING); + nameRows.appendRow(new String[] {"name", "age"}); + nameRows.appendRow(new String[] {"x", "y"}); + + fb.setNames(nameRows); + } + + @Test(expected = DMLRuntimeException.class) + public void testSetNamesWrongColCount() { + FrameBlock fb = new FrameBlock(3, ValueType.STRING); + FrameBlock nameRow = new FrameBlock(2, ValueType.STRING); + nameRow.appendRow(new String[] {"a", "b"}); + + fb.setNames(nameRow); + } +} + From a46fad315c16a5388c5c99b8a9993cc2afe0f59d Mon Sep 17 00:00:00 2001 From: YenFuChen Date: Thu, 3 Jul 2025 11:51:45 +0200 Subject: [PATCH 2/4] Add apache license header to [BuiltinGetSetNamesTest.java] --- .../BuiltinGetSetNamesTest.java | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java index f779667e53a..0daf6a236e0 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java @@ -1,5 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ package org.apache.sysds.test.functions.builtin.BuiltinGetSetNamesTest; - import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.common.Types.ValueType; From c6d0b13c33e1ae34bfb33e791b6d1bd1f636ba5c Mon Sep 17 00:00:00 2001 From: YenFuChen Date: Sun, 6 Jul 2025 12:28:17 +0200 Subject: [PATCH 3/4] Add test with dml script for setNames and getNames function --- .../part1/BuiltinGetSetNamesScriptTest.java | 76 +++++++++++++++++++ .../BuiltinGetSetNamesTest.java | 3 +- .../builtin/BuiltinGetSetNamesTest.dml | 29 +++++++ 3 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java rename src/test/java/org/apache/sysds/test/functions/builtin/{BuiltinGetSetNamesTest => part1}/BuiltinGetSetNamesTest.java (95%) create mode 100644 src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java new file mode 100644 index 00000000000..d696dfd34b3 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.builtin.part1; + +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.AutomatedTestBase; +import org.junit.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import static org.junit.Assert.fail; +import static org.junit.Assert.assertArrayEquals; + +public class BuiltinGetSetNamesScriptTest extends AutomatedTestBase { + + private static final Log LOG = LogFactory.getLog(BuiltinGetSetNamesScriptTest.class); + + private static final String TEST_NAME = "BuiltinGetSetNamesTest"; + private static final String TEST_DIR = "functions/builtin/part1/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGetSetNamesScriptTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, + new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"N"})); + } + + @Test + public void testSetNamesAndGetNames() { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml"; + programArgs = new String[] { "-args", output("N") }; + + runTest(true, false, null, -1); + + String actualOutputPath = output("N"); + String actualContent; + + try { + actualContent = new String(Files.readAllBytes(Paths.get(actualOutputPath))).trim(); + String[] actualNames = actualContent.split(","); + + String[] expectedNames = new String[]{"name", "age"}; + + assertArrayEquals("Column names mismatch.", expectedNames, actualNames); + + } catch (IOException e) { + LOG.error("Failed to read test files: " + e.getMessage(), e); + fail("Failed to read test files: " + e.getMessage()); + } + } +} + diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesTest.java similarity index 95% rename from src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java rename to src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesTest.java index 0daf6a236e0..d924ef1cebc 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGetSetNamesTest/BuiltinGetSetNamesTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesTest.java @@ -16,9 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sysds.test.functions.builtin.BuiltinGetSetNamesTest; +package org.apache.sysds.test.functions.builtin.part1; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.junit.Test; diff --git a/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml b/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml new file mode 100644 index 00000000000..ae6c011da9a --- /dev/null +++ b/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = as.frame(matrix(c("Alice", "25", "Bob", "30"), rows=2), + types=["string", "string"]) + +X = setNames(X, matrix(c("name", "age"), rows=1)) + +N = getNames(X) + +write(N, $OUTPUT, format="csv") From ed81e4de93297ce02d398123751ae3bb2b0cbb3f Mon Sep 17 00:00:00 2001 From: YenFuChen Date: Wed, 6 Aug 2025 00:37:13 +0300 Subject: [PATCH 4/4] fix getColNames and setColNames functions --- .gitignore | 4 ++ .../org/apache/sysds/common/Builtins.java | 4 +- .../java/org/apache/sysds/common/Opcodes.java | 2 + .../parser/BuiltinFunctionExpression.java | 4 +- .../sysds/runtime/frame/data/FrameBlock.java | 4 +- .../cp/BinaryFrameFrameCPInstruction.java | 6 ++ .../cp/UnaryFrameCPInstruction.java | 6 ++ .../part1/BuiltinGetSetNamesScriptTest.java | 57 +++++++------------ .../builtin/BuiltinGetSetNamesTest.dml | 25 ++++++-- 9 files changed, 63 insertions(+), 49 deletions(-) diff --git a/.gitignore b/.gitignore index f3c28571bdf..6db9af8e619 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,7 @@ venv/* # resource optimization scripts/resource/output *.pem +*.log +build_log.txt +*.log +build_log.txt diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index d2360bb30cf..627e0ae2181 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -404,8 +404,8 @@ public enum Builtins { UPPER_TRI("upper.tri", false, true), XDUMMY1("xdummy1", true), //error handling test XDUMMY2("xdummy2", true), //error handling test - GETNAMES("getNames", false, true), - SETNAMES("setNames", false, true); + GETCOLNAMES("getColNames", false, true), + SETCOLNAMES("setColNames", false, true); Builtins(String name, boolean script) { this(name, null, script, false, ReturnType.SINGLE_RETURN); diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 28c5a7a6a8e..d28117ed314 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -123,6 +123,7 @@ public enum Opcodes { FREPLICATE("freplicate", InstructionType.Binary), VALUESWAP("valueSwap", InstructionType.Binary), APPLYSCHEMA("applySchema", InstructionType.Binary), + SETCOLNAMES("setColNames", InstructionType.Binary), MAP("_map", InstructionType.Ternary), NMAX("nmax", InstructionType.BuiltinNary), @@ -164,6 +165,7 @@ public enum Opcodes { TYPEOF("typeOf", InstructionType.Unary), DETECTSCHEMA("detectSchema", InstructionType.Unary), COLNAMES("colnames", InstructionType.Unary), + GETCOLNAMES("getColNames", InstructionType.Unary), ISNA("isna", InstructionType.Unary), ISNAN("isnan", InstructionType.Unary), ISINF("isinf", InstructionType.Unary), diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 1310756c18c..12f527bfaf6 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -752,7 +752,7 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); break; - case GETNAMES: + case GETCOLNAMES: checkNumParameters(1); Expression getNamesExpr = getFirstExpr(); validFrameInput(getNamesExpr, _opcode.toString()); @@ -764,7 +764,7 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) getNamesOut.setBlocksize(getNamesExpr.getOutput().getBlocksize()); break; - case SETNAMES: + case SETCOLNAMES: checkNumParameters(2); Expression target = getFirstExpr(); Expression nameRow = getSecondExpr(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 34d01c84587..e146e1f69b4 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -366,11 +366,11 @@ public void setColumnName(int index, String name) { * @return an array of column names * The actual function is the same to getColumnNamesAsFrame() */ - public FrameBlock getNames() { + public FrameBlock getColNames() { return getColumnNamesAsFrame(); } - public void setNames(FrameBlock names) { + public void setColNames(FrameBlock names) { if (names == null){ throw new DMLRuntimeException("Input FrameBlock can not be null."); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java index e9771b2e7fe..30a30f4464c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java @@ -63,6 +63,12 @@ else if(getOpcode().equals(Opcodes.APPLYSCHEMA.toString())) { final FrameBlock out = FrameLibApplySchema.applySchema(inBlock1, inBlock2, k); ec.setFrameOutput(output.getName(), out); } + else if (getOpcode().equals(Opcodes.SETCOLNAMES.toString())) { + FrameBlock fb = ec.getFrameInput(input1.getName()); + FrameBlock nameRow = ec.getFrameInput(input2.getName()); + fb.setColNames(nameRow); + ec.setFrameOutput(output.getName(), fb); + } else { // Execute binary operations BinaryOperator dop = (BinaryOperator) _optr; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java index 107cab79d79..a39de10d30a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java @@ -52,6 +52,12 @@ else if(getOpcode().equals(Opcodes.COLNAMES.toString())) { ec.releaseFrameInput(input1.getName()); ec.setFrameOutput(output.getName(), retBlock); } + else if (getOpcode().equals(Opcodes.GETCOLNAMES.toString())) { + FrameBlock inBlock = ec.getFrameInput(input1.getName()); + FrameBlock retBlock = inBlock.getColNames(); + ec.releaseFrameInput(input1.getName()); + ec.setFrameOutput(output.getName(), retBlock); + } else throw new DMLScriptException("Opcode '" + getOpcode() + "' is not a valid UnaryFrameCPInstruction"); } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java index d696dfd34b3..ecc1a0763e8 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java @@ -19,58 +19,39 @@ package org.apache.sysds.test.functions.builtin.part1; -import org.apache.sysds.test.TestConfiguration; -import org.apache.sysds.test.AutomatedTestBase; import org.junit.Test; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.common.Types.ExecMode; +import java.io.BufferedReader; +import java.io.FileReader; import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Paths; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import static org.junit.Assert.fail; -import static org.junit.Assert.assertArrayEquals; public class BuiltinGetSetNamesScriptTest extends AutomatedTestBase { - - private static final Log LOG = LogFactory.getLog(BuiltinGetSetNamesScriptTest.class); - private static final String TEST_NAME = "BuiltinGetSetNamesTest"; - private static final String TEST_DIR = "functions/builtin/part1/"; + private static final String TEST_DIR = "functions/builtin/"; private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGetSetNamesScriptTest.class.getSimpleName() + "/"; @Override public void setUp() { - addTestConfiguration(TEST_NAME, - new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"N"})); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"})); + setExecMode(ExecMode.SINGLE_NODE); } @Test - public void testSetNamesAndGetNames() { - TestConfiguration config = getTestConfiguration(TEST_NAME); - loadTestConfiguration(config); - + public void testGetSetNames() { fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml"; - programArgs = new String[] { "-args", output("N") }; - + String tempFilePath = output("B"); + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + programArgs = new String[]{"-args", tempFilePath}; runTest(true, false, null, -1); - - String actualOutputPath = output("N"); - String actualContent; - - try { - actualContent = new String(Files.readAllBytes(Paths.get(actualOutputPath))).trim(); - String[] actualNames = actualContent.split(","); - - String[] expectedNames = new String[]{"name", "age"}; - - assertArrayEquals("Column names mismatch.", expectedNames, actualNames); - + try (BufferedReader br = new BufferedReader(new FileReader(tempFilePath))) { + String header = br.readLine(); + if (header == null || !header.equals("ID,Value")) { + throw new AssertionError("Test failed: Expected header 'ID,Value', but got: " + header); + } } catch (IOException e) { - LOG.error("Failed to read test files: " + e.getMessage(), e); - fail("Failed to read test files: " + e.getMessage()); + throw new AssertionError("Test failed: Unable to read output file: " + e.getMessage()); } } -} - +} \ No newline at end of file diff --git a/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml b/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml index ae6c011da9a..c8a87754abb 100644 --- a/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml +++ b/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml @@ -19,11 +19,26 @@ # #------------------------------------------------------------- -X = as.frame(matrix(c("Alice", "25", "Bob", "30"), rows=2), - types=["string", "string"]) -X = setNames(X, matrix(c("name", "age"), rows=1)) +tempFile = $1 -N = getNames(X) +data = matrix(c(1, 2, 3, 4), 2, 2) +frame1 = as.frame(data) +colNames1 = as.frame(matrix(c("ID", "Value"), 1, 2)) -write(N, $OUTPUT, format="csv") +frame1 = setColNames(frame1, colNames1) +retrievedNames = getColNames(frame1) + +if (!all(retrievedNames == colNames1[1,])) { + stop("Name mismatch: Expected " + toString(colNames1[1,]) + " but got " + toString(retrievedNames)) +} + +write(frame1, tempFile, format="csv", header=TRUE) +frame2 = read(tempFile, format="csv", header=TRUE) + +reloadedNames = getColNames(frame2) +if (!all(reloadedNames == colNames1[1,])) { + stop("CSV reload name mismatch: Expected " + toString(colNames1[1,]) + " but got " + toString(reloadedNames)) +} + +print("All tests passed successfully!")