diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index a14079160f3..0825965f36c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.instructions.ooc.CSVReblockOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.DataGenOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ParameterizedBuiltinOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; @@ -79,6 +80,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str return CentralMomentOOCInstruction.parseInstruction(str); case Ctable: return CtableOOCInstruction.parseInstruction(str); + case Rand: + return DataGenOOCInstruction.parseInstruction(str); case ParameterizedBuiltin: return ParameterizedBuiltinOOCInstruction.parseInstruction(str); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java new file mode 100644 index 00000000000..355c8ddea1e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+/** Out-of-core data generation instruction; only the sequence generator is implemented so far. */
+public class DataGenOOCInstruction extends UnaryOOCInstruction {
+	private final int blen;
+	private final Types.OpOpDG method;
+
+	// sequence specific attributes
+	private final CPOperand seq_from, seq_to, seq_incr;
+
+	public DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd, CPOperand in, CPOperand out, int blen, CPOperand seqFrom,
+		CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) {
+		super(OOCType.Rand, op, in, out, opcode, istr);
+		this.blen = blen;
+		this.method = mthd;
+		this.seq_from = seqFrom;
+		this.seq_to = seqTo;
+		this.seq_incr = seqIncr;
+	}
+
+	/**
+	 * Parses a serialized data generation instruction.
+	 *
+	 * @param str instruction string; only the sequence opcode is accepted
+	 * @return the parsed instruction
+	 * @throws NotImplementedException if the opcode is not the sequence opcode
+	 */
+	public static DataGenOOCInstruction parseInstruction(String str) {
+		String[] s = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = s[0];
+
+		if(!opcode.equalsIgnoreCase(Opcodes.SEQUENCE.toString()))
+			throw new NotImplementedException(); // TODO
+
+		// 7 operands: rows, cols, blen, from, to, incr, outvar
+		InstructionUtils.checkNumFields(s, 7);
+		CPOperand out = new CPOperand(s[s.length - 1]);
+		int blen = Integer.parseInt(s[3]);
+		CPOperand from = new CPOperand(s[4]);
+		CPOperand to = new CPOperand(s[5]);
+		CPOperand incr = new CPOperand(s[6]);
+
+		// neither an operator nor an input operand is passed for seq
+		return new DataGenOOCInstruction(null, Types.OpOpDG.SEQ, null, out, blen, from, to, incr, opcode, str);
+	}
+
+	/**
+	 * Submits a task that enqueues the sequence in blocks of at most {@code blen} rows and closes the stream.
+	 *
+	 * @param ec execution context used to read the scalar inputs and to attach the stream to the output
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		final OOCStream qOut = createWritableStream();
+
+		// process specific datagen operator
+		if(method == Types.OpOpDG.SEQ) {
+			double lfrom = ec.getScalarInput(seq_from).getDoubleValue();
+			double lto = ec.getScalarInput(seq_to).getDoubleValue();
+			double lincr = ec.getScalarInput(seq_incr).getDoubleValue();
+
+			// handle default 1 to -1 for special case of from>to
+			lincr = LibMatrixDatagen.updateSeqIncr(lfrom, lto, lincr);
+
+			if(LOG.isTraceEnabled())
+				LOG.trace(
+					"Process DataGenOOCInstruction seq with seqFrom=" + lfrom + ", seqTo=" + lto + ", seqIncr=" + lincr);
+
+			// keep the length as long: narrowing to int would silently wrap for very long sequences
+			final long maxK = UtilFunctions.getSeqLength(lfrom, lto, lincr);
+			final double finalLincr = lincr;
+
+			submitOOCTask(() -> {
+				long k = 0;
+				double curFrom = lfrom;
+				while (k < maxK) {
+					long desiredLen = Math.min(blen, maxK - k);
+					double curTo = curFrom + (desiredLen - 1) * finalLincr;
+					long actualLen = UtilFunctions.getSeqLength(curFrom, curTo, finalLincr);
+
+					if (actualLen != desiredLen) {
+						// Then we add / subtract a small correction term
+						curTo += (actualLen < desiredLen) ? finalLincr / 2 : -finalLincr / 2;
+
+						if (UtilFunctions.getSeqLength(curFrom, curTo, finalLincr) != desiredLen)
+							throw new DMLRuntimeException("OOC seq could not construct the right number of elements.");
+					}
+
+					MatrixBlock mb = MatrixBlock.seqOperations(curFrom, curTo, finalLincr);
+					qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(1 + k / blen, 1), mb));
+					curFrom = mb.get(mb.getNumRows() - 1, 0) + finalLincr;
+					k += blen;
+				}
+
+				qOut.closeInput();
+			}, qOut);
+		}
+		else
+			throw new NotImplementedException();
+
+		ec.getMatrixObject(output).setStreamHandle(qOut);
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 7abf593aba0..570daa21c9a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -54,7 +54,7 @@ public abstract class OOCInstruction extends Instruction {
 	private static final AtomicInteger nextStreamId = new AtomicInteger(0);
 
 	public enum OOCType {
-		Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin
+		Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand
 	}
 
 	protected final OOCInstruction.OOCType _ooctype;
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SeqTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SeqTest.java
new file mode 100644
index 00000000000..f7855b93e2d
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/SeqTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class SeqTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "Seq";
+	private final static String TEST_DIR = "functions/ooc/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + SeqTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+	private static final String OUTPUT_NAME = "res";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+	}
+
+	@Test
+	public void testSeq1() {
+		runSeqTest(0, 10, 0.1);
+	}
+
+	@Test
+	public void testSeq2() {
+		runSeqTest(0, 15.9, 0.01);
+	}
+
+	/**
+	 * Runs the seq script once with and once without the -ooc flag and compares the two written matrices.
+	 */
+	private void runSeqTest(double from, double to, double incr) {
+		final Types.ExecMode previousMode = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME1);
+			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml";
+			final String fromArg = Double.toString(from), toArg = Double.toString(to), incrArg = Double.toString(incr);
+
+			// first run: with -ooc, the seq instruction must appear among the heavy hitters
+			final String oocPath = output(OUTPUT_NAME);
+			programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", fromArg, toArg, incrArg, oocPath};
+			runTest(true, false, null, -1);
+			Assert.assertTrue("OOC wasn't used for seq",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.SEQUENCE));
+
+			// second run: same script without the -ooc flag, written to a separate target file
+			final String referencePath = output(OUTPUT_NAME + "_target");
+			programArgs = new String[] {"-explain", "-stats", "-args", fromArg, toArg, incrArg, referencePath};
+			runTest(true, false, null, -1);
+
+			// both runs must have written the same values (within eps)
+			MatrixBlock oocResult = TestUtils.readBinary(oocPath);
+			MatrixBlock reference = TestUtils.readBinary(referencePath);
+			TestUtils.compareMatrices(oocResult, reference, eps);
+		}
+		finally {
+			resetExecMode(previousMode);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/ooc/Seq.dml b/src/test/scripts/functions/ooc/Seq.dml
new file mode 100644
index 00000000000..f596f1da7f4
--- /dev/null
+++ b/src/test/scripts/functions/ooc/Seq.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Generate a sequence from the start, end, and increment passed as script arguments
+from = $1;
+to = $2;
+incr = $3;
+
+res = seq(from, to, incr);
+
+write(res, $4, format="binary");