diff --git a/src/lava/lib/optimization/solvers/generic/read_gate/models.py b/src/lava/lib/optimization/solvers/generic/read_gate/models.py index 70392db9..09b28566 100644 --- a/src/lava/lib/optimization/solvers/generic/read_gate/models.py +++ b/src/lava/lib/optimization/solvers/generic/read_gate/models.py @@ -48,18 +48,18 @@ def run_spk(self): cost = self.cost_in.recv() if cost[0]: self.min_cost = cost[0] - self.cost_out.send(np.array([0])) + self.cost_out.send(np.array([0], dtype=np.int32)) elif self.solution is not None: - timestep = - np.array([self.time_step]) + timestep = - np.array([self.time_step], dtype=np.int32) if self.min_cost <= self.target_cost: self._req_pause = True - self.cost_out.send(np.array([self.min_cost])) + self.cost_out.send(np.array([self.min_cost], dtype=np.int32)) self.send_pause_request.send(timestep) - self.solution_out.send(self.solution) + self.solution_out.send(self.solution.astype(np.int32)) self.solution = None self.min_cost = None else: - self.cost_out.send(np.array([0])) + self.cost_out.send(np.array([0], dtype=np.int32)) def run_post_mgmt(self): """Execute post management phase.""" diff --git a/tests/lava/tutorials/test_tutorials.py b/tests/lava/tutorials/test_tutorials.py index 90864b88..768b1fd3 100644 --- a/tests/lava/tutorials/test_tutorials.py +++ b/tests/lava/tutorials/test_tutorials.py @@ -21,7 +21,7 @@ class TestTutorials(unittest.TestCase): system_name = platform.system().lower() def _execute_notebook(self, base_dir: str, path: str) -> \ - ty.Tuple[ty.Type[nbformat.NotebookNode], ty.List[str]]: + int: """Execute a notebook via nbconvert and collect output. Parameters @@ -33,23 +33,23 @@ def _execute_notebook(self, base_dir: str, path: str) -> \ Returns ------- - Tuple - (parsed nbformat.NotebookNode object, list of execution errors) + int + (return code) """ cwd = os.getcwd() dir_name, notebook = os.path.split(path) try: env = self._update_pythonpath(base_dir, dir_name) - nb = self._convert_and_execute_notebook(notebook, env) - errors = self._collect_errors_from_all_cells(nb) + result = self._convert_and_execute_notebook(notebook, env) + errors = self._collect_errors_from_all_cells(result) except Exception as e: nb = None errors = str(e) finally: os.chdir(cwd) - return nb, errors + return errors def _update_pythonpath(self, base_dir: str, dir_name: str) \ -> ty.Dict[str, str]: @@ -94,17 +94,16 @@ def _convert_and_execute_notebook(self, notebook: str, nb : nbformat.NotebookNode Notebook dict-like node with attribute-access """ - with tempfile.NamedTemporaryFile(mode="w+t", suffix=".ipynb") \ + with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py") \ as fout: - args = ["jupyter", "nbconvert", "--to", "notebook", "--execute", - "--ExecutePreprocessor.timeout=-1", + args = ["jupyter", "nbconvert", "--to", "python", "--output", fout.name, notebook] subprocess.check_call(args, env=env) # noqa: S603 fout.seek(0) - return nbformat.read(fout, nbformat.current_nbformat) + return subprocess.run(["ipython", "-c", fout.read()], env=env) # noqa - def _collect_errors_from_all_cells(self, nb: nbformat.NotebookNode) \ + def _collect_errors_from_all_cells(self, result) \ -> ty.List[str]: """Collect errors from executed notebook. @@ -118,13 +117,9 @@ def _collect_errors_from_all_cells(self, nb: nbformat.NotebookNode) \ List Collection of errors """ - errors = [] - for cell in nb.cells: - if 'outputs' in cell: - for output in cell['outputs']: - if output.output_type == 'error': - errors.append(output) - return errors + if result.returncode != 0: + result.check_returncode() + return result.returncode def _run_notebook(self, notebook: str): """Run a specific notebook @@ -157,18 +152,10 @@ def _run_notebook(self, notebook: str): # If the notebook is found execute it and store any errors for notebook_name in discovered_notebooks: - nb, errors = self._execute_notebook( - str(tutorials_directory), - notebook_name + errors = self._execute_notebook( + str(tutorials_directory), notebook_name ) - errors_joined = "\n".join(errors) if isinstance( - errors, list) else errors - if errors: - errors_record[notebook_name] = (errors_joined, nb) - - self.assertFalse(errors_record, - "Failed to execute Jupyter Notebooks \ - with errors: \n {}".format(errors_record)) + self.assertEqual(errors, 0) finally: os.chdir(cwd)