diff --git a/.ado/ci.yml b/.ado/ci.yml index 7a943be2..a69aee59 100644 --- a/.ado/ci.yml +++ b/.ado/ci.yml @@ -1,91 +1,97 @@ trigger: -- main + - main pr: -- main + - main schedules: -- cron: "0 9 * * Sat" - displayName: 'Build for Component Governance' - branches: - include: - - main - always: true + - cron: "0 9 * * Sat" + displayName: "Build for Component Governance" + branches: + include: + - main + always: true variables: -- name: QSHARP_PYTHON_TELEMETRY - value: none # Disable usage telemetry for internal test pipelines -- name: PYTEST_MAX_PARALLEL_TESTS - value: 'auto' + - name: QSHARP_PYTHON_TELEMETRY + value: none # Disable usage telemetry for internal test pipelines + - name: PYTEST_MAX_PARALLEL_TESTS + value: "auto" jobs: -- job: "Build_Azure_Quantum_Python" - displayName: Build "azure-quantum" package - pool: - vmImage: 'windows-latest' - - steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.11' - displayName: Set Python version - - - script: | - pip install wheel - displayName: Install wheel - - - script: | - pip freeze - displayName: List installed packages - - - script: | - cd $(Build.SourcesDirectory)/azure-quantum - python setup.py sdist --dist-dir=target/wheels - python setup.py bdist_wheel --dist-dir=target/wheels - displayName: Build azure-quantum package - - - publish: $(Build.SourcesDirectory)/azure-quantum/target/wheels/ - artifact: azure-quantum-wheels - displayName: Upload azure-quantum artifacts - -- job: "Test_Azure_Quantum_Python" - displayName: Test "azure-quantum" package - pool: - vmImage: 'windows-latest' - - steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.11' - displayName: Set Python version - - - script: | - python -m pip install --upgrade pip - pip install pytest pytest-azurepipelines pytest-cov pytest-xdist tox - displayName: Install test dependencies - - - script: | - pip freeze - displayName: List installed packages - - - script: | - cd $(Build.SourcesDirectory)/azure-quantum - pip install ".[cirq,qsharp,dev]" - pytest --numprocesses $(PYTEST_MAX_PARALLEL_TESTS) --cov-report term --cov=azure.quantum --junitxml test-output-azure-quantum.xml --ignore tests/unit/test_qiskit.py --ignore tests/unit/test_session_qiskit.py $(Build.SourcesDirectory)/azure-quantum - displayName: Run azure-quantum unit tests - - - script: | - cd $(Build.SourcesDirectory)/azure-quantum - tox -e py311-qiskit1,py311-qiskit2 - displayName: Run Qiskit matrix tests - - - task: PublishTestResults@2 - displayName: 'Publish tests results (python)' - condition: succeededOrFailed() - inputs: - testResultsFormat: 'JUnit' - testResultsFiles: '**/test-*.xml' - testRunTitle: 'Azure Quantum Python Tests' - - - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 - displayName: 'Component detection' + - job: "Build_Azure_Quantum_Python" + displayName: Build "azure-quantum" package + pool: + vmImage: "windows-latest" + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: "3.11" + displayName: Set Python version + + - script: | + pip install wheel + displayName: Install wheel + + - script: | + pip freeze + displayName: List installed packages + + - script: | + cd $(Build.SourcesDirectory)/azure-quantum + python setup.py sdist --dist-dir=target/wheels + python setup.py bdist_wheel --dist-dir=target/wheels + displayName: Build azure-quantum package + + - publish: $(Build.SourcesDirectory)/azure-quantum/target/wheels/ + artifact: azure-quantum-wheels + displayName: Upload azure-quantum artifacts + + - job: "Test_Azure_Quantum_Python" + displayName: Test "azure-quantum" package + pool: + vmImage: "windows-latest" + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: "3.11" + displayName: Set Python version + + - script: | + python -m pip install --upgrade pip + pip install pytest pytest-azurepipelines pytest-cov pytest-xdist tox + displayName: Install test dependencies + + - script: | + pip freeze + displayName: List installed packages + + - script: | + cd $(Build.SourcesDirectory)/azure-quantum + pip install ".[cirq,qsharp,dev]" + pytest --numprocesses $(PYTEST_MAX_PARALLEL_TESTS) --cov-report term --cov=azure.quantum --junitxml test-output-azure-quantum.xml --ignore tests/unit/test_qiskit.py --ignore tests/unit/test_session_qiskit.py $(Build.SourcesDirectory)/azure-quantum --ignore tests/unit/local + displayName: Run azure-quantum unit tests + + - script: | + cd $(Build.SourcesDirectory)/azure-quantum + # Run local-only tests explicitly to include nested path under unit/local + pytest --numprocesses $(PYTEST_MAX_PARALLEL_TESTS) --cov=azure.quantum --cov-append --junitxml test-output-azure-quantum-local.xml tests/unit/local + displayName: Run azure-quantum local unit tests + + - script: | + cd $(Build.SourcesDirectory)/azure-quantum + tox -e py311-qiskit1,py311-qiskit2 + displayName: Run Qiskit matrix tests + + - task: PublishTestResults@2 + displayName: "Publish tests results (python)" + condition: succeededOrFailed() + inputs: + testResultsFormat: "JUnit" + testResultsFiles: "**/test-*.xml" + testRunTitle: "Azure Quantum Python Tests" + + - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 + displayName: "Component detection" diff --git a/azure-quantum/tests/unit/local/mock_client.py b/azure-quantum/tests/unit/local/mock_client.py new file mode 100644 index 00000000..36753a75 --- /dev/null +++ b/azure-quantum/tests/unit/local/mock_client.py @@ -0,0 +1,665 @@ +## +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +## + +""" +Mock Azure Quantum REST client used to back a real Workspace +without making network calls. Returns real SDK models and ItemPaged. +""" + +from typing import List, Optional +from datetime import datetime, UTC, timedelta + +from azure.core.paging import ItemPaged +from azure.quantum.workspace import Workspace +from types import SimpleNamespace +from azure.quantum._client import WorkspaceClient +from azure.quantum._client.models import JobDetails, SessionDetails, ItemDetails +from azure.quantum._client.models import SasUriResponse +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams +from common import ( + SUBSCRIPTION_ID, + RESOURCE_GROUP, + LOCATION, + ENDPOINT_URI, + WORKSPACE, +) + + +def _paged(items: List, page_size: int = 100) -> ItemPaged: + """Create an ItemPaged that conforms to azure-core's contract. + - get_next(token) returns a response payload + - extract_data(response) returns (items_iterable, next_link) + """ + + def get_next(token): + start = int(token) if token is not None else 0 + end = start + page_size + page = items[start:end] + next_link = str(end) if end < len(items) else None + # Return a dict-like payload as expected by extract_data + return {"items": page, "next_link": next_link} + + def extract_data(response): + # Return (iterable, next_link) per azure.core.paging contract + if response is None: + return None, [] + items_iter = response.get("items") or [] + next_link = response.get("next_link") + # azure.core.paging expects (continuation_token, items) + return next_link, items_iter + + return ItemPaged(get_next, extract_data) + + +def _apply_filter(items: List, filter_expr: Optional[str]) -> List: + """Apply a minimal OData-like filter generated by Workspace._create_filter. + Supports: + - startswith(Name, 'prefix') + - Property eq 'value' (with or groups inside parentheses) + - CreationTime ge/le YYYY-MM-DD + Properties: Name, ItemType, JobType, ProviderId, Target, State, CreationTime + """ + if not filter_expr: + return items + + def matches(item) -> bool: + expr = filter_expr + # Handle startswith(Name, 'prefix') optionally combined with ' and ' + conds = [c.strip() for c in expr.split(" and ")] + + def eval_simple(condition: str) -> bool: + # startswith(Name, 'x') (case-sensitive to match Workspace._create_filter) + if condition.startswith("startswith("): + try: + inside = condition[len("startswith(") : -1] + prop, value = inside.split(",", 1) + prop = prop.strip() + value = value.strip().strip("'") + name = getattr(item, "name", None) + return isinstance(name, str) and name.startswith(value) + except Exception: + return False + # Parenthesized OR: (A or B or C) + if condition.startswith("(") and condition.endswith(")"): + inner = condition[1:-1] + parts = [p.strip() for p in inner.split(" or ")] + return any(eval_simple(p) for p in parts) + # Equality: Prop eq 'value' + if " eq " in condition: + try: + left, right = condition.split(" eq ", 1) + prop = left.strip() + val = right.strip().strip("'") + # Map property names to model attributes + mapping = { + "Name": "name", + "ItemType": "item_type", + "JobType": "job_type", + "ProviderId": "provider_id", + "Target": "target", + "State": "status", + } + attr = mapping.get(prop) + if not attr: + return False + item_val = getattr(item, attr, None) + return item_val == val + except Exception: + return False + # CreationTime ge/le YYYY-MM-DD + if "CreationTime ge " in condition or "CreationTime le " in condition: + try: + if " ge " in condition: + _, date_str = condition.split(" ge ", 1) + cmp_date = datetime.fromisoformat(date_str.strip()) + ct = getattr(item, "creation_time", None) + return bool(ct) and ct.date() >= cmp_date.date() + if " le " in condition: + _, date_str = condition.split(" le ", 1) + cmp_date = datetime.fromisoformat(date_str.strip()) + ct = getattr(item, "creation_time", None) + return bool(ct) and ct.date() <= cmp_date.date() + except Exception: + return False + return False + + return all(eval_simple(c) for c in conds) + + return [it for it in items if matches(it)] + + +class JobsOperations: + def __init__(self, store: List[JobDetails]) -> None: + self._store = store + + def create_or_replace( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + job_id: str, + job_details: JobDetails, + ) -> JobDetails: + # Preserve provided status; default only if missing + if getattr(job_details, "status", None) is None: + job_details.status = "Submitted" + # Ensure creation_time present + if not getattr(job_details, "creation_time", None): + job_details.creation_time = datetime.now(UTC) + # Upsert by id + for i, jd in enumerate(self._store): + if jd.id == job_id: + self._store[i] = job_details + break + else: + self._store.append(job_details) + return job_details + + # New WorkspaceClient API: create + def create( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + job_id: str, + resource: JobDetails, + ) -> JobDetails: + return self.create_or_replace( + subscription_id, + resource_group_name, + workspace_name, + job_id, + resource, + ) + + def get( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + job_id: str, + ) -> JobDetails: + for jd in self._store: + if jd.id == job_id: + return jd + raise KeyError(job_id) + + # Cancel/delete for older API; mark job as cancelled + def delete( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + job_id: str, + ) -> None: + for jd in self._store: + if jd.id == job_id: + jd.status = "Cancelled" + return None + raise KeyError(job_id) + + def list( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + filter: Optional[str] = None, + orderby: Optional[str] = None, + top: int = 100, + skip: int = 0, + ) -> ItemPaged[JobDetails]: + items = list(self._store) + # Apply filter + items = _apply_filter(items, filter) + # Only basic orderby support for CreationTime asc/desc + if orderby: + try: + prop, direction = orderby.split() + if prop == "CreationTime": + items.sort( + key=lambda j: getattr(j, "creation_time", datetime.now(UTC)), + reverse=(direction == "desc"), + ) + except Exception: + pass + return _paged(items[skip : skip + top], page_size=top) + + +class SessionsOperations: + def __init__( + self, store: List[SessionDetails], jobs_store: List[JobDetails] + ) -> None: + self._store = store + self._jobs_store = jobs_store + + def create_or_replace( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + session_id: str, + session_details: SessionDetails, + ) -> SessionDetails: + if getattr(session_details, "status", None) is None: + session_details.status = "WAITING" + if not getattr(session_details, "creation_time", None): + session_details.creation_time = datetime.utcnow() + for i, sd in enumerate(self._store): + if sd.id == session_id: + self._store[i] = session_details + break + else: + self._store.append(session_details) + return session_details + + # New WorkspaceClient API: open + def open( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + session_id: str, + resource: SessionDetails, + ) -> SessionDetails: + return self.create_or_replace( + subscription_id, + resource_group_name, + workspace_name, + session_id, + resource, + ) + + def close( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + session_id: str, + ) -> SessionDetails: + sd = self.get(subscription_id, resource_group_name, workspace_name, session_id) + sd.status = "SUCCEEDED" + return sd + + def get( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + session_id: str, + ) -> SessionDetails: + for sd in self._store: + if sd.id == session_id: + return sd + raise KeyError(session_id) + + def list( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + filter: Optional[str] = None, + orderby: Optional[str] = None, + skip: int = 0, + top: int = 100, + ) -> ItemPaged[SessionDetails]: + items = list(self._store) + items = _apply_filter(items, filter) + if orderby: + try: + prop, direction = orderby.split() + if prop == "CreationTime": + items.sort( + key=lambda s: getattr(s, "creation_time", datetime.now(UTC)), + reverse=(direction == "desc"), + ) + except Exception: + pass + return _paged(items[skip : skip + top], page_size=top) + + # New WorkspaceClient API: listv2 (same behavior as list) + def listv2( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + filter: Optional[str] = None, + orderby: Optional[str] = None, + skip: int = 0, + top: int = 100, + ) -> ItemPaged[SessionDetails]: + return self.list( + subscription_id, + resource_group_name, + workspace_name, + filter, + orderby, + skip, + top, + ) + + def jobs_list( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + session_id: str, + filter: Optional[str] = None, + orderby: Optional[str] = None, + skip: int = 0, + top: int = 100, + ) -> ItemPaged[JobDetails]: + jobs = [ + j for j in self._jobs_store if getattr(j, "session_id", None) == session_id + ] + jobs = _apply_filter(jobs, filter) + if orderby: + try: + prop, direction = orderby.split() + if prop == "CreationTime": + jobs.sort( + key=lambda j: getattr(j, "creation_time", datetime.now(UTC)), + reverse=(direction == "desc"), + ) + except Exception: + pass + return _paged(jobs[skip : skip + top], page_size=top) + + +class TopLevelItemsOperations: + def __init__( + self, jobs_store: List[JobDetails], sessions_store: List[SessionDetails] + ) -> None: + self._jobs_store = jobs_store + self._sessions_store = sessions_store + + def list( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + filter: Optional[str] = None, + orderby: Optional[str] = None, + top: int = 100, + skip: int = 0, + ) -> ItemPaged[ItemDetails]: + items: List[ItemDetails] = [] + # Build JobDetails and SessionDetails via mapping-based init to mimic server responses + for j in self._jobs_store: + job_mapping = { + "id": j.id, + "itemType": "Job", + "name": getattr(j, "name", j.id), + "providerId": getattr(j, "provider_id", None), + "target": getattr(j, "target", None), + "creationTime": getattr(j, "creation_time", datetime.now(UTC)), + "jobType": getattr(j, "job_type", None), + # Status is read-only but present in service responses; include if available + "status": getattr(j, "status", None), + } + items.append(JobDetails(job_mapping)) + for s in self._sessions_store: + session_mapping = { + "id": s.id, + "itemType": "Session", + "name": getattr(s, "name", s.id), + "providerId": getattr(s, "provider_id", None), + "target": getattr(s, "target", None), + "creationTime": getattr(s, "creation_time", datetime.now(UTC)), + # Required in model; set a sensible default for mock responses + "jobFailurePolicy": getattr(s, "job_failure_policy", "Abort"), + "status": getattr(s, "status", None), + } + items.append(SessionDetails(session_mapping)) + # Apply filter across heterogeneous items + items = _apply_filter(items, filter) + if orderby: + try: + prop, direction = orderby.split() + if prop == "CreationTime": + items.sort( + key=lambda i: getattr(i, "creation_time", datetime.now(UTC)), + reverse=(direction == "desc"), + ) + except Exception: + pass + return _paged(items[skip : skip + top], page_size=top) + + # New WorkspaceClient API: listv2 + def listv2( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + filter: Optional[str] = None, + orderby: Optional[str] = None, + top: int = 100, + skip: int = 0, + ) -> ItemPaged[ItemDetails]: + return self.list( + subscription_id, + resource_group_name, + workspace_name, + filter, + orderby, + top, + skip, + ) + + +class ProvidersOperations: + def list( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + ) -> ItemPaged: + # Minimal stub: return empty provider list + return _paged([], page_size=100) + + +class QuotasOperations: + def list( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + ) -> ItemPaged: + # Minimal stub: return empty quotas list + return _paged([], page_size=100) + + +class StorageOperations: + def get_sas_uri( + self, + subscription_id: str, + resource_group_name: str, + workspace_name: str, + *, + blob_details: object, + ) -> SasUriResponse: + # Return a dummy SAS URI suitable for tests that might exercise storage + return SasUriResponse({"sasUri": "https://example.com/container?sas-token"}) + + +class MockWorkspaceMgmtClient: + """Mock management client that avoids network calls to ARM/ARG.""" + + def __init__( + self, + credential: Optional[object] = None, + base_url: Optional[str] = None, + user_agent: Optional[str] = None, + ) -> None: + self._credential = credential + self._base_url = base_url + self._user_agent = user_agent + + def close(self) -> None: + """No-op close for mock.""" + pass + + def __enter__(self) -> "MockWorkspaceMgmtClient": + return self + + def __exit__(self, *exc_details) -> None: + pass + + def load_workspace_from_arg( + self, connection_params: WorkspaceConnectionParams + ) -> None: + connection_params.subscription_id = SUBSCRIPTION_ID + connection_params.resource_group = RESOURCE_GROUP + connection_params.location = LOCATION + connection_params.quantum_endpoint = ENDPOINT_URI + + def load_workspace_from_arm( + self, connection_params: WorkspaceConnectionParams + ) -> None: + connection_params.location = LOCATION + connection_params.quantum_endpoint = ENDPOINT_URI + + +class MockWorkspaceClient: + def __init__(self, authentication_policy: Optional[object] = None) -> None: + # in-memory stores + self._jobs_store: List[JobDetails] = [] + self._sessions_store: List[SessionDetails] = [] + # operations grouped under .services to mirror WorkspaceClient + self.services = SimpleNamespace( + jobs=JobsOperations(self._jobs_store), + sessions=SessionsOperations(self._sessions_store, self._jobs_store), + top_level_items=TopLevelItemsOperations( + self._jobs_store, self._sessions_store + ), + providers=ProvidersOperations(), + quotas=QuotasOperations(), + storage=StorageOperations(), + ) + # Mimic WorkspaceClient config shape for tests that inspect policy + self._config = SimpleNamespace(authentication_policy=authentication_policy) + + def __enter__(self) -> "MockWorkspaceClient": + return self + + def __exit__(self, *exc_details) -> None: + pass + + def close(self) -> None: + pass + + +class WorkspaceMock(Workspace): + def __init__(self, **kwargs) -> None: + # Create and pass mock management client to prevent network calls + if "_mgmt_client" not in kwargs: + kwargs["_mgmt_client"] = MockWorkspaceMgmtClient() + super().__init__(**kwargs) + + def _create_client(self) -> WorkspaceClient: # type: ignore[override] + # Pass through the Workspace's auth policy to the mock client + auth_policy = self._connection_params.get_auth_policy() + return MockWorkspaceClient(authentication_policy=auth_policy) + + +def seed_jobs(ws: WorkspaceMock) -> None: + base = datetime.now(UTC) - timedelta(days=10) + samples = [ + JobDetails( + id="j-ionq-1", + name="ionqJobA", + provider_id="ionq", + target="ionq.simulator", + status="Succeeded", + creation_time=base + timedelta(days=1), + session_id="s-ionq-1", + job_type="QuantumComputing", + ), + JobDetails( + id="j-ionq-2", + name="ionqJobB", + provider_id="ionq", + target="ionq.simulator", + status="Failed", + creation_time=base + timedelta(days=2), + session_id="s-ionq-1", + ), + JobDetails( + id="j-qh-1", + name="qhJobA", + provider_id="quantinuum", + target="quantinuum.sim", + status="Cancelled", + creation_time=base + timedelta(days=3), + session_id="s-ionq-2", + job_type="QuantumChemistry", + ), + JobDetails( + id="j-ms-1", + name="msJobA", + provider_id="microsoft", + target="microsoft.estimator", + status="Succeeded", + creation_time=base + timedelta(days=4), + ), + JobDetails( + id="j-ionq-ms-qc", + name="ionqMsQC", + provider_id="ionq", + target="microsoft.estimator", + status="Succeeded", + creation_time=base + timedelta(days=5), + job_type="QuantumComputing", + ), + JobDetails( + id="j-rig-1", + name="rigJobA", + provider_id="rigetti", + target="rigetti.qpu", + status="Succeeded", + ), + ] + for d in samples: + ws._client.services.jobs.create_or_replace( + ws.subscription_id, ws.resource_group, ws.name, job_id=d.id, job_details=d + ) + + +def seed_sessions(ws: WorkspaceMock) -> None: + base = datetime.now(UTC) - timedelta(days=5) + samples = [ + SessionDetails( + id="s-ionq-1", + name="sessionA", + provider_id="ionq", + target="ionq.simulator", + status="Succeeded", + creation_time=base + timedelta(days=1), + ), + SessionDetails( + id="s-ionq-2", + name="sessionB", + provider_id="ionq", + target="ionq.test", + status="Succeeded", + creation_time=base + timedelta(days=2), + ), + ] + for s in samples: + ws._client.services.sessions.create_or_replace( + ws.subscription_id, + ws.resource_group, + ws.name, + session_id=s.id, + session_details=s, + ) + + +def create_default_workspace() -> WorkspaceMock: + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE + ) + seed_jobs(ws) + seed_sessions(ws) + return ws diff --git a/azure-quantum/tests/unit/local/test_job_results.py b/azure-quantum/tests/unit/local/test_job_results.py new file mode 100644 index 00000000..ccbde0e9 --- /dev/null +++ b/azure-quantum/tests/unit/local/test_job_results.py @@ -0,0 +1,336 @@ +## +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +## + +import pytest +from unittest.mock import Mock +from azure.quantum import Job, JobDetails + + +def _mock_job(output_data_format: str, results_as_json_str: str, status: str = "Succeeded") -> Job: + job_details = JobDetails( + id="", + name="", + provider_id="", + target="", + container_uri="", + input_data_format="", + output_data_format=output_data_format, + ) + job_details.status = status + job = Job(workspace=None, job_details=job_details) + + job.has_completed = Mock(return_value=True) + job.wait_until_completed = Mock() + + class DowloadDataMock(object): + def decode(): + str + + pass + + download_data = DowloadDataMock() + download_data.decode = Mock(return_value=results_as_json_str) + job.download_data = Mock(return_value=download_data) + + return job + + +def _get_job_results(output_data_format: str, results_as_json_str: str, status: str = "Succeeded"): + job = _mock_job(output_data_format, results_as_json_str, status) + return job.get_results() + + +def _get_job_results_histogram(output_data_format: str, results_as_json_str: str): + job = _mock_job(output_data_format, results_as_json_str) + return job.get_results_histogram() + + +def _get_job_results_shots(output_data_format: str, results_as_json_str: str): + job = _mock_job(output_data_format, results_as_json_str) + return job.get_results_shots() + + +def test_job_success(): + job_results = _get_job_results( + "test_output_data_format", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) + assert len(job_results["Histogram"]) == 4 + + +def test_job_for_microsoft_quantum_results_v1_success(): + job_results = _get_job_results( + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) + assert len(job_results.keys()) == 2 + assert job_results["[0]"] == 0.50 + assert job_results["[1]"] == 0.50 + + +def test_job_get_results_with_completed_status(): + job_results = _get_job_results( + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + "Completed", + ) + assert len(job_results.keys()) == 2 + assert job_results["[0]"] == 0.50 + assert job_results["[1]"] == 0.50 + + +def test_job_get_results_with_failed_status_raises_runtime_error(): + with pytest.raises(RuntimeError, match="Cannot retrieve results as job execution failed"): + _get_job_results( + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + "Failed", + ) + + +def test_job_get_results_with_cancelled_status_raises_runtime_error(): + with pytest.raises(RuntimeError, match="Cannot retrieve results as job execution failed"): + _get_job_results( + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + "Cancelled", + ) + + +def test_job_for_microsoft_quantum_results_v1_no_histogram_returns_raw_result(): + job_result_raw = '{"NotHistogramProperty": ["[0]", 0.50, "[1]", 0.50]}' + job_result = _get_job_results("microsoft.quantum-results.v1", job_result_raw) + assert job_result == job_result_raw + + +def test_job_for_microsoft_quantum_results_v1_invalid_histogram_returns_raw_result(): + job_result_raw = '{"NotHistogramProperty": ["[0]", 0.50, "[1]"]}' + job_result = _get_job_results("microsoft.quantum-results.v1", job_result_raw) + assert job_result == job_result_raw + + +def test_job_for_microsoft_quantum_results_v2_success(): + job_results = _get_job_results( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) + assert len(job_results.keys()) == 2 + assert job_results["[0]"] == 0.50 + assert job_results["[1]"] == 0.50 + + +def test_job_for_microsoft_quantum_results_v2_wrong_type_returns_raw(): + job_result_raw = '{"DataFormat": "microsoft.quantum-results.v1", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}' + job_result = _get_job_results("microsoft.quantum-results.v2", job_result_raw) + assert job_result == job_result_raw + + +def test_job_for_microsoft_quantum_results_v2_invalid_histogram_returns_raw_result(): + job_result_raw = '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]"}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}' + job_result = _get_job_results("microsoft.quantum-results.v2", job_result_raw) + assert job_result == job_result_raw + + +def test_job_for_microsoft_quantum_results_histogram_v2_success(): + job_results = _get_job_results_histogram( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) + assert len(job_results.keys()) == 2 + assert job_results["[0]"]["count"] == 2 + assert job_results["[1]"]["count"] == 2 + assert job_results["[0]"]["outcome"] == [0] + assert job_results["[1]"]["outcome"] == [1] + + +def test_job_for_microsoft_quantum_results_histogram_batch_v2_success(): + job_results = _get_job_results_histogram( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}, {"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}, {"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) + assert len(job_results) == 3 + for result in job_results: + assert len(result.keys()) == 2 + assert result["[0]"]["count"] == 2 + assert result["[1]"]["count"] == 2 + assert result["[0]"]["outcome"] == [0] + assert result["[1]"]["outcome"] == [1] + + +def test_job_for_microsoft_quantum_results_histogram_v2_wrong_type_raises_exception(): + try: + _get_job_results_histogram( + "microsoft.quantum-results.v2", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) + assert False + except Exception: + assert True + + +def test_job_for_microsoft_quantum_results_shots_v2_success(): + job_results = _get_job_results_shots( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) + assert len(job_results) == 4 + assert job_results[0] == [0] + assert job_results[1] == [1] + assert job_results[2] == [1] + assert job_results[3] == [0] + + +def test_job_for_microsoft_quantum_results_shots_batch_v2_success(): + job_results = _get_job_results_shots( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}, {"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}, {"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) + assert len(job_results) == 3 + for i in range(3): + assert len(job_results[i]) == 4 + assert job_results[i][0] == [0] + assert job_results[i][1] == [1] + assert job_results[i][2] == [1] + assert job_results[i][3] == [0] + + +def test_job_for_microsoft_quantum_results_histogram_v2_tuple_success(): + output = """{ + \"DataFormat\": \"microsoft.quantum-results.v2\", + \"Results\": [ + { + \"Histogram\": [ + { + \"Outcome\": { + \"Item1\": [1, 0], + \"Item2\": { + \"Item1\": -2.71, + \"Item2\": 67 + }, + \"Item3\": [ + { + \"Item1\": 6, + \"Item2\": true + }, + { + \"Item1\": 12, + \"Item2\": false + } + ] + }, + \"Display\": \"([1, 0], (-2.71, 67), [(6, true), (12, false)])\", + \"Count\": 1 + }, + { + \"Outcome\": [1, 0], + \"Display\": \"[1, 0]\", + \"Count\": 1 + }, + { + \"Outcome\": [1], + \"Display\": \"[1]\", + \"Count\": 1 + } + ], + \"Shots\": [ + { + \"Item1\": [1, 0], + \"Item2\": { + \"Item1\": -2.71, + \"Item2\": 67 + }, + \"Item3\": [ + { + \"Item1\": 6, + \"Item2\": true + }, + { + \"Item1\": 12, + \"Item2\": false + } + ] + }, + [1, 0], + [1] + ] + } + ] +}""" + job_results = _get_job_results_histogram("microsoft.quantum-results.v2", output) + assert len(job_results.keys()) == 3 + assert job_results["[1, 0]"]["count"] == 1 + assert job_results["[1]"]["count"] == 1 + assert job_results["([1, 0], (-2.71, 67), [(6, true), (12, false)])"]["count"] == 1 + assert job_results["([1, 0], (-2.71, 67), [(6, true), (12, false)])"][ + "outcome" + ] == ([1, 0], (-2.71, 67), [(6, True), (12, False)]) + assert job_results["[1]"]["outcome"] == [1] + assert job_results["[1, 0]"]["outcome"] == [1, 0] + + +def test_job_for_microsoft_quantum_results_shots_v2_tuple_success(): + output = """{ + \"DataFormat\": \"microsoft.quantum-results.v2\", + \"Results\": [ + { + \"Histogram\": [ + { + \"Outcome\": { + \"Item1\": [ + 1, + 0 + ], + \"Item2\": { + \"Item1\": -2.71, + \"Item2\": 67 + } + }, + \"Display\": \"([1, 0], (-2.71, 67))\", + \"Count\": 1 + }, + { + \"Outcome\": [1, 0], + \"Display\": \"[1, 0]\", + \"Count\": 1 + }, + { + \"Outcome\": [1], + \"Display\": \"[1]\", + \"Count\": 1 + } + ], + \"Shots\": [ + { + \"Item1\": [ + 1, + 0 + ], + \"Item2\": { + \"Item1\": -2.71, + \"Item2\": 67 + } + }, + [1, 0], + [1] + ] + } + ] + }""" + job_results = _get_job_results_shots("microsoft.quantum-results.v2", output) + assert len(job_results) == 3 + assert job_results[0] == ([1, 0], (-2.71, 67)) + assert job_results[1] == [1, 0] + assert job_results[2] == [1] + + +def test_job_for_microsoft_quantum_results_shots_v2_wrong_type_raises_exception(): + try: + _get_job_results_shots( + "microsoft.quantum-results.v2", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) + assert False + except Exception: + assert True diff --git a/azure-quantum/tests/unit/local/test_mgmt_client.py b/azure-quantum/tests/unit/local/test_mgmt_client.py new file mode 100644 index 00000000..cad4e4c8 --- /dev/null +++ b/azure-quantum/tests/unit/local/test_mgmt_client.py @@ -0,0 +1,659 @@ +## +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +## + +import pytest +from unittest.mock import MagicMock, patch +from http import HTTPStatus +from azure.core.exceptions import HttpResponseError +from azure.quantum._mgmt_client import WorkspaceMgmtClient +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams +from azure.quantum._constants import ConnectionConstants +from common import ( + SUBSCRIPTION_ID, + RESOURCE_GROUP, + WORKSPACE, + LOCATION, + ENDPOINT_URI, +) + + +def test_init_creates_client(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + + client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + assert client._credential == mock_credential + assert client._base_url == base_url + assert client._client is not None + assert len(client._policies) == 5 + + +def test_init_without_user_agent(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + + client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url + ) + + assert client._credential == mock_credential + assert client._base_url == base_url + assert client._client is not None + + +def test_context_manager_enter(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + with patch.object(mgmt_client._client, '__enter__', return_value=mgmt_client._client): + result = mgmt_client.__enter__() + assert result == mgmt_client + + +def test_context_manager_exit(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + with patch.object(mgmt_client._client, '__exit__') as mock_exit: + mgmt_client.__exit__(None, None, None) + mock_exit.assert_called_once() + + +def test_close(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + with patch.object(mgmt_client._client, 'close') as mock_close: + mgmt_client.close() + mock_close.assert_called_once() + + +def test_load_workspace_from_arg_success(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + connection_params.subscription_id = None + connection_params.location = None + connection_params.quantum_endpoint = None + + mgmt_client.load_workspace_from_arg(connection_params) + + assert connection_params.subscription_id == SUBSCRIPTION_ID + assert connection_params.resource_group == RESOURCE_GROUP + assert connection_params.workspace_name == WORKSPACE + assert connection_params.location == LOCATION + assert connection_params.quantum_endpoint == ENDPOINT_URI + + +def test_load_workspace_from_arg_with_resource_group_filter(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + resource_group=RESOURCE_GROUP + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert RESOURCE_GROUP in str(request.content) + + +def test_load_workspace_from_arg_with_location_filter(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + location=LOCATION + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert LOCATION in str(request.content) + + +def test_load_workspace_from_arg_with_subscription_filter(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + request_body = request.content + assert 'subscriptions' in request_body + + +def test_load_workspace_from_arg_no_workspace_name(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams() + + with pytest.raises(ValueError, match="Workspace name must be specified"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arg_no_matching_workspace(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = {'data': []} + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="No matching workspace found"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arg_multiple_workspaces(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [ + { + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }, + { + 'name': WORKSPACE, + 'subscriptionId': 'another-sub-id', + 'resourceGroup': 'another-rg', + 'location': 'westus', + 'endpointUri': 'https://another.endpoint.com/' + } + ] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Multiple Azure Quantum workspaces found"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arg_incomplete_workspace_data(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve complete workspace details"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arg_request_exception(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + with patch.object(mgmt_client._client, 'send_request', side_effect=Exception("Network error")): + with pytest.raises(RuntimeError, match="Could not load workspace details from Azure Resource Graph"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arm_success(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + connection_params.location = None + connection_params.quantum_endpoint = None + + mgmt_client.load_workspace_from_arm(connection_params) + + assert connection_params.location == LOCATION + assert connection_params.quantum_endpoint == ENDPOINT_URI + + +def test_load_workspace_from_arm_missing_required_params(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE + ) + + with pytest.raises(ValueError, match="Missing required connection parameters"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_workspace_not_found(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_error = HttpResponseError() + mock_error.status_code = HTTPStatus.NOT_FOUND + + with patch.object(mgmt_client._client, 'send_request', side_effect=mock_error): + with pytest.raises(ValueError, match="not found in resource group"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_http_error(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_error = HttpResponseError() + mock_error.status_code = HTTPStatus.FORBIDDEN + + with patch.object(mgmt_client._client, 'send_request', side_effect=mock_error): + with pytest.raises(HttpResponseError): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_missing_location(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve location"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_missing_endpoint(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': {} + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve endpoint uri"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_request_exception(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + with patch.object(mgmt_client._client, 'send_request', side_effect=Exception("Network error")): + with pytest.raises(RuntimeError, match="Could not load workspace details from ARM"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_uses_custom_api_version(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_version="2024-01-01" + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert "2024-01-01" in request.url + + +def test_load_workspace_from_arm_uses_default_api_version(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert ConnectionConstants.DEFAULT_WORKSPACE_API_VERSION in request.url + + +def test_load_workspace_from_arg_constructs_correct_url(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert "/providers/Microsoft.ResourceGraph/resources" in request.url + assert ConnectionConstants.DEFAULT_ARG_API_VERSION in request.url + + +def test_load_workspace_from_arm_constructs_correct_url(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert f"/subscriptions/{SUBSCRIPTION_ID}" in request.url + assert f"/resourceGroups/{RESOURCE_GROUP}" in request.url + assert f"/providers/Microsoft.Quantum/workspaces/{WORKSPACE}" in request.url diff --git a/azure-quantum/tests/unit/local/test_pagination.py b/azure-quantum/tests/unit/local/test_pagination.py new file mode 100644 index 00000000..507a0e7c --- /dev/null +++ b/azure-quantum/tests/unit/local/test_pagination.py @@ -0,0 +1,339 @@ +## +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +## + +from datetime import datetime, UTC, timedelta + +from mock_client import create_default_workspace + + +def test_list_jobs_basic(): + ws = create_default_workspace() + jobs = list(ws.list_jobs()) + assert all(j.item_type == "Job" for j in jobs) + assert len(jobs) >= 4 + + +def test_list_jobs_filters(): + ws = create_default_workspace() + # name prefix + jobs = list(ws.list_jobs(name_match="ionq")) + assert jobs and all(j.details.name.startswith("ionq") for j in jobs) + # provider + jobs = list(ws.list_jobs(provider=["ionq"])) + assert jobs and all(j.details.provider_id == "ionq" for j in jobs) + # target + jobs = list(ws.list_jobs(target=["microsoft.estimator", "microsoft.dft"])) + assert all( + j.details.target in {"microsoft.estimator", "microsoft.dft"} for j in jobs + ) + # status + jobs = list(ws.list_jobs(status=["Failed", "Cancelled"])) + assert all(j.details.status in {"Failed", "Cancelled"} for j in jobs) + + +def test_list_jobs_created_window_and_ordering(): + ws = create_default_workspace() + after = datetime.now(UTC) - timedelta(days=9) + before = datetime.now(UTC) + timedelta(days=1) + # asc + asc = list( + ws.list_jobs( + created_after=after, + created_before=before, + orderby_property="CreationTime", + is_asc=True, + ) + ) + assert all( + j.details.creation_time.date() >= after.date() + and j.details.creation_time.date() <= before.date() + for j in asc + ) + for a, b in zip(asc, asc[1:]): + assert a.details.creation_time <= b.details.creation_time + # desc + desc = list( + ws.list_jobs( + created_after=after, + created_before=before, + orderby_property="CreationTime", + is_asc=False, + ) + ) + for a, b in zip(desc, desc[1:]): + assert a.details.creation_time >= b.details.creation_time + # missing creation_time default handling ensures item is included and sortable + all_jobs = list(ws.list_jobs(orderby_property="CreationTime", is_asc=True)) + assert any(j.details.id == "j-rig-1" for j in all_jobs) + + +def test_list_jobs_paging_basic(): + ws = create_default_workspace() + jobs = ws.list_jobs(orderby_property="CreationTime", is_asc=True) + # Ensure iterable and ordered + jobs_list = list(jobs) + assert len(jobs_list) >= 1 + for a, b in zip(jobs_list, jobs_list[1:]): + assert a.details.creation_time <= b.details.creation_time + + +def test_list_sessions_basic_and_filters(): + ws = create_default_workspace() + sessions = list(ws.list_sessions()) + assert all(s.item_type == "Session" for s in sessions) + # provider filter + f = list(ws.list_sessions(provider=["ionq"])) + assert f and all(s._details.provider_id == "ionq" for s in f) + # target filter + t = list(ws.list_sessions(target=["ionq.test", "ionq.simulator"])) + assert t and all(s._details.target in {"ionq.test", "ionq.simulator"} for s in t) + # status filter + st = list(ws.list_sessions(status=["Succeeded"])) + assert st and all(s._details.status == "Succeeded" for s in st) + # multi-value ORs + prov_or = ws.list_sessions(provider=["ionq", "quantinuum"]) + assert prov_or and all( + s._details.provider_id in {"ionq", "quantinuum"} for s in prov_or + ) + st_or = ws.list_sessions(status=["Succeeded", "WAITING"]) + assert st_or and all(s._details.status in {"Succeeded", "WAITING"} for s in st_or) + + +def test_list_sessions_created_ordering(): + ws = create_default_workspace() + before = datetime.now(UTC) + timedelta(days=1) + asc = list( + ws.list_sessions( + created_before=before, orderby_property="CreationTime", is_asc=True + ) + ) + for a, b in zip(asc, asc[1:]): + assert a.details.creation_time <= b.details.creation_time + desc = list( + ws.list_sessions( + created_before=before, orderby_property="CreationTime", is_asc=False + ) + ) + for a, b in zip(desc, desc[1:]): + assert a.details.creation_time >= b.details.creation_time + + +def test_list_session_jobs_filters_and_order(): + ws = create_default_workspace() + sessions = list(ws.list_sessions()) + assert sessions + sid = sessions[0].id + jobs = list(ws.list_session_jobs(session_id=sid)) + assert jobs and all( + j.item_type == "Job" and j._details.session_id == sid for j in jobs + ) + jn = list(ws.list_session_jobs(session_id=sid, name_match="ionqJob")) + assert all(j.details.name.startswith("ionqJob") for j in jn) + js = list(ws.list_session_jobs(session_id=sid, status=["Succeeded"])) + assert all(j.details.status == "Succeeded" for j in js) + asc = list( + ws.list_session_jobs( + session_id=sid, orderby_property="CreationTime", is_asc=True + ) + ) + for a, b in zip(asc, asc[1:]): + assert a.details.creation_time <= b.details.creation_time + desc = list( + ws.list_session_jobs( + session_id=sid, orderby_property="CreationTime", is_asc=False + ) + ) + for a, b in zip(desc, desc[1:]): + assert a.details.creation_time >= b.details.creation_time + + +def test_list_top_level_items_basic_and_filters(): + ws = create_default_workspace() + items = list(ws.list_top_level_items()) + assert all(i.workspace.subscription_id == ws.subscription_id for i in items) + # name filters + i1 = list(ws.list_top_level_items(name_match="ionq")) + assert all(it.details.name.startswith("ionq") for it in i1) + # exact-case only; mixed-case not supported per API + # provider + # combined provider AND status AND window + before = datetime.now(UTC) + timedelta(days=1) + combo = list( + ws.list_sessions(provider=["ionq"], status=["Succeeded"], created_before=before) + ) + assert combo and all( + s._details.provider_id == "ionq" and s._details.status == "Succeeded" + for s in combo + ) + prov = list(ws.list_top_level_items(provider=["ionq"])) + assert prov and all(it.details.provider_id == "ionq" for it in prov) + # target + tgt = list(ws.list_top_level_items(target=["microsoft.estimator", "microsoft.dft"])) + assert all( + it.details.target in {"microsoft.estimator", "microsoft.dft"} for it in tgt + ) + # status + st = list(ws.list_top_level_items(status=["Failed", "Cancelled"])) + assert all(it.details.status in {"Failed", "Cancelled"} for it in st) + # combined filters: provider AND target; with seeded AND-match expect results + combo = list( + ws.list_top_level_items(provider=["ionq"], target=["microsoft.estimator"]) + ) + assert combo and all( + it.details.provider_id == "ionq" and it.details.target == "microsoft.estimator" + for it in combo + ) + + # case sensitivity: lower-case item_type should return empty + combo_case = list(ws.list_top_level_items(item_type=["job"])) + assert len(combo_case) == 0 + + # multi-value OR grouping for item_type should return both types + both_types = list(ws.list_top_level_items(item_type=["Job", "Session"])) + assert ( + both_types + and any(it.item_type == "Job" for it in both_types) + and any(it.item_type == "Session" for it in both_types) + ) + + # multi-value OR grouping for job_type should include both QuantumComputing and QuantumChemistry + jt_multi = list( + ws.list_top_level_items(job_type=["QuantumComputing", "QuantumChemistry"]) + ) + assert ( + jt_multi + and any( + getattr(it.details, "job_type", None) == "QuantumComputing" + for it in jt_multi + ) + and any( + getattr(it.details, "job_type", None) == "QuantumChemistry" + for it in jt_multi + ) + ) + + # date boundary tests: created_after/on boundary includes items; created_before/on boundary includes items + # choose a boundary based on a known seeded item creation_time + boundary_date = next( + it.details.creation_time.date() for it in items if it.details.name == "msJobA" + ) + after_inclusive = list( + ws.list_top_level_items( + created_after=datetime.combine( + boundary_date, datetime.min.time(), tzinfo=UTC + ) + ) + ) + assert any( + it.details.creation_time.date() >= boundary_date for it in after_inclusive + ) + before_inclusive = list( + ws.list_top_level_items( + created_before=datetime.combine( + boundary_date, datetime.min.time(), tzinfo=UTC + ) + ) + ) + assert any( + it.details.creation_time.date() <= boundary_date for it in before_inclusive + ) + # job_type + provider + target (AND semantics); with seeded combo expect non-empty + jt_combo = list( + ws.list_top_level_items( + job_type=["QuantumComputing"], + provider=["ionq"], + target=["microsoft.estimator"], + ) + ) + assert jt_combo and all( + getattr(it.details, "job_type", None) == "QuantumComputing" + and it.details.provider_id == "ionq" + and it.details.target == "microsoft.estimator" + for it in jt_combo + ) + # negative test: no match + none_items = list(ws.list_top_level_items(provider=["no-provider"])) + assert len(none_items) == 0 + + +def test_list_top_level_items_created_ordering(): + ws = create_default_workspace() + after = datetime.now(UTC) - timedelta(days=15) + asc = list( + ws.list_top_level_items( + created_after=after, orderby_property="CreationTime", is_asc=True + ) + ) + for a, b in zip(asc, asc[1:]): + assert a.details.creation_time <= b.details.creation_time + desc = list( + ws.list_top_level_items( + created_after=after, orderby_property="CreationTime", is_asc=False + ) + ) + for a, b in zip(desc, desc[1:]): + assert a.details.creation_time >= b.details.creation_time + # Ascending with created_after boundary + start = datetime.now(UTC) - timedelta(days=365) + items_after = list( + ws.list_top_level_items( + created_after=start, orderby_property="CreationTime", is_asc=True + ) + ) + assert items_after + prev = None + for it in items_after: + assert it.details.creation_time.date() >= start.date() + if prev is None: + prev = it.details.creation_time + else: + assert it.details.creation_time >= prev + prev = it.details.creation_time + + +def test_filter_string_emission(): + ws = create_default_workspace() + # pylint: disable=protected-access + filter_string = ws._create_filter( + job_name="name", + item_type=["Session", "Job"], + job_type=["Regular", "Chemistry"], + provider_ids=["ionq", "quantinuum"], + target=["ionq.sim", "quantinuum,sim"], + status=["Completed", "Failed"], + created_after=datetime(2024, 10, 1), + created_before=datetime(2024, 11, 1), + ) + # pylint: enable=protected-access + expected = ( + "startswith(Name, 'name') and (ItemType eq 'Session' or ItemType eq 'Job') and " + "(JobType eq 'Regular' or JobType eq 'Chemistry') and (ProviderId eq 'ionq' or ProviderId eq 'quantinuum') and " + "(Target eq 'ionq.sim' or Target eq 'quantinuum,sim') and (State eq 'Completed' or State eq 'Failed') and " + "CreationTime ge 2024-10-01 and CreationTime le 2024-11-01" + ) + assert filter_string == expected + + +def test_orderby_emission_and_validation(): + ws = create_default_workspace() + props = [ + "Name", + "ItemType", + "JobType", + "ProviderId", + "Target", + "State", + "CreationTime", + ] + # pylint: disable=protected-access + for p in props: + assert ws._create_orderby(p, True) == f"{p} asc" + assert ws._create_orderby(p, False) == f"{p} desc" + try: + ws._create_orderby("test", True) + assert False, "Expected ValueError for invalid property" + except ValueError: + pass + # pylint: enable=protected-access diff --git a/azure-quantum/tests/unit/local/test_session.py b/azure-quantum/tests/unit/local/test_session.py new file mode 100644 index 00000000..55b6a0da --- /dev/null +++ b/azure-quantum/tests/unit/local/test_session.py @@ -0,0 +1,41 @@ +## +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +## + +from mock_client import create_default_workspace + + +def test_list_top_level_items_includes_jobs_and_sessions(): + ws = create_default_workspace() + items = list(ws.list_top_level_items()) + assert items + item_types = {type(it).__name__ for it in items} + assert "Job" in item_types + assert "Session" in item_types + + +def test_list_sessions_basic(): + ws = create_default_workspace() + sessions = list(ws.list_sessions()) + assert sessions + assert all(type(s).__name__ == "Session" for s in sessions) + + +def test_get_session_returns_matching_details_and_jobs(): + ws = create_default_workspace() + # Choose a known session from the seeded data + sessions = list(ws.list_sessions()) + assert sessions + sid = sessions[0].id + + s = ws.get_session(session_id=sid) + assert s + assert s.id == sid + assert s.details.id == sid + + # Verify session-scoped jobs are returned and have matching session_id + jobs = list(s.list_jobs()) + assert jobs + assert all(j.item_type == "Job" for j in jobs) + assert all(getattr(j._details, "session_id", None) == sid for j in jobs) diff --git a/azure-quantum/tests/unit/local/test_workspace.py b/azure-quantum/tests/unit/local/test_workspace.py new file mode 100644 index 00000000..9016ed2f --- /dev/null +++ b/azure-quantum/tests/unit/local/test_workspace.py @@ -0,0 +1,490 @@ +## +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +## + +import os +from unittest import mock +from azure.quantum._constants import EnvironmentVariables, ConnectionConstants +from azure.core.credentials import AzureKeyCredential +from azure.core.pipeline.policies import AzureKeyCredentialPolicy +from azure.identity import EnvironmentCredential + +from mock_client import WorkspaceMock, MockWorkspaceMgmtClient +from common import ( + SUBSCRIPTION_ID, + RESOURCE_GROUP, + WORKSPACE, + LOCATION, + STORAGE, + API_KEY, + ENDPOINT_URI, +) + +SIMPLE_RESOURCE_ID = ConnectionConstants.VALID_RESOURCE_ID( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, +) + +SIMPLE_CONNECTION_STRING = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT(LOCATION), +) + +SIMPLE_CONNECTION_STRING_V2 = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT_v2(LOCATION) +) + + +def test_create_workspace_instance_valid(): + def assert_all_required_params(ws: WorkspaceMock): + assert ws.subscription_id == SUBSCRIPTION_ID + assert ws.resource_group == RESOURCE_GROUP + assert ws.name == WORKSPACE + assert ws.location == LOCATION + assert ws._connection_params.quantum_endpoint == ENDPOINT_URI + + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + storage=STORAGE, + ) + assert_all_required_params(ws) + assert ws.storage == STORAGE + + ws = WorkspaceMock( + resource_id=SIMPLE_RESOURCE_ID, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + resource_id=SIMPLE_RESOURCE_ID, + storage=STORAGE, + ) + assert_all_required_params(ws) + assert ws.storage == STORAGE + + ws = WorkspaceMock( + name=WORKSPACE, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + name=WORKSPACE, + storage=STORAGE, + ) + assert_all_required_params(ws) + assert ws.storage == STORAGE + + ws = WorkspaceMock( + name=WORKSPACE, + location=LOCATION, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID, + location=LOCATION, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + name=WORKSPACE, + resource_group=RESOURCE_GROUP, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + name=WORKSPACE, + resource_group=RESOURCE_GROUP, + location=LOCATION, + ) + assert_all_required_params(ws) + + +def test_create_workspace_locations(): + # Location name should be normalized + _mgmt_client = MockWorkspaceMgmtClient() + def mock_load_workspace_from_arm(connection_params): + connection_params.location = "East US" + connection_params.quantum_endpoint = ENDPOINT_URI + _mgmt_client.load_workspace_from_arm = mock_load_workspace_from_arm + + ws = WorkspaceMock( + name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + location="East US", + _mgmt_client=_mgmt_client, + ) + assert ws.location == "eastus" + + +def test_env_connection_string(): + with mock.patch.dict(os.environ): + # Clear env vars then set connection string + os.environ.clear() + os.environ[EnvironmentVariables.CONNECTION_STRING] = SIMPLE_CONNECTION_STRING + + workspace = WorkspaceMock() + assert workspace.location == LOCATION + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.name == WORKSPACE + assert workspace.resource_group == RESOURCE_GROUP + assert isinstance(workspace.credential, AzureKeyCredential) + assert workspace.credential.key == API_KEY + # pylint: disable=protected-access + assert isinstance( + workspace._client._config.authentication_policy, AzureKeyCredentialPolicy + ) + auth_policy = workspace._client._config.authentication_policy + assert auth_policy._name == ConnectionConstants.QUANTUM_API_KEY_HEADER + assert id(auth_policy._credential) == id(workspace.credential) + + +def test_workspace_from_connection_string(): + with mock.patch.dict(os.environ): + os.environ.clear() + workspace = WorkspaceMock.from_connection_string(SIMPLE_CONNECTION_STRING) + assert workspace.location == LOCATION + assert isinstance(workspace.credential, AzureKeyCredential) + assert workspace.credential.key == API_KEY + # pylint: disable=protected-access + assert isinstance( + workspace._client._config.authentication_policy, AzureKeyCredentialPolicy + ) + auth_policy = workspace._client._config.authentication_policy + assert auth_policy._name == ConnectionConstants.QUANTUM_API_KEY_HEADER + assert id(auth_policy._credential) == id(workspace.credential) + + # Ensure env var overrides behave as original tests expect + with mock.patch.dict(os.environ): + os.environ.clear() + + wrong_subscription_id = "00000000-2BAD-2BAD-2BAD-000000000000" + wrong_resource_group = "wrongrg" + wrong_workspace = "wrong-workspace" + wrong_location = "westus" + + wrong_connection_string = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=wrong_subscription_id, + resource_group=wrong_resource_group, + workspace_name=wrong_workspace, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT( + wrong_location + ), + ) + + os.environ[EnvironmentVariables.CONNECTION_STRING] = wrong_connection_string + os.environ[EnvironmentVariables.LOCATION] = LOCATION + os.environ[EnvironmentVariables.SUBSCRIPTION_ID] = SUBSCRIPTION_ID + os.environ[EnvironmentVariables.RESOURCE_GROUP] = RESOURCE_GROUP + os.environ[EnvironmentVariables.WORKSPACE_NAME] = WORKSPACE + + workspace = WorkspaceMock() + assert workspace.location == LOCATION + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.resource_group == RESOURCE_GROUP + assert workspace.name == WORKSPACE + assert isinstance(workspace.credential, AzureKeyCredential) + + # If a credential is passed, it should be used + workspace = WorkspaceMock(credential=EnvironmentCredential()) + assert isinstance(workspace.credential, EnvironmentCredential) + + # Parameter connection string should override env var + os.environ.clear() + os.environ[EnvironmentVariables.CONNECTION_STRING] = wrong_connection_string + connection_string = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT( + LOCATION + ), + ) + workspace = WorkspaceMock.from_connection_string(connection_string) + assert workspace.location == LOCATION + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.resource_group == RESOURCE_GROUP + assert workspace.name == WORKSPACE + + # Bad env var connection string should not be parsed if not needed + os.environ.clear() + os.environ[EnvironmentVariables.CONNECTION_STRING] = "bad-connection-string" + connection_string = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT( + LOCATION + ), + ) + workspace = WorkspaceMock.from_connection_string(connection_string) + assert workspace.location == LOCATION + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.resource_group == RESOURCE_GROUP + assert workspace.name == WORKSPACE + +def test_workspace_from_connection_string_v2(): + """Test that v2 QuantumEndpoint format is correctly parsed.""" + with mock.patch.dict( + os.environ, + clear=True + ): + workspace = WorkspaceMock.from_connection_string(SIMPLE_CONNECTION_STRING_V2) + assert workspace.location == LOCATION + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.resource_group == RESOURCE_GROUP + assert workspace.name == WORKSPACE + assert isinstance(workspace.credential, AzureKeyCredential) + assert workspace.credential.key == API_KEY + # pylint: disable=protected-access + assert isinstance( + workspace._client._config.authentication_policy, + AzureKeyCredentialPolicy) + auth_policy = workspace._client._config.authentication_policy + assert auth_policy._name == ConnectionConstants.QUANTUM_API_KEY_HEADER + assert id(auth_policy._credential) == id(workspace.credential) + +def test_workspace_from_connection_string_v2_dogfood(): + """Test v2 QuantumEndpoint with dogfood environment.""" + canary_location = "eastus2euap" + dogfood_connection_string_v2 = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_DOGFOOD_ENDPOINT_v2(canary_location) + ) + + with mock.patch.dict(os.environ, clear=True): + workspace = WorkspaceMock.from_connection_string(dogfood_connection_string_v2) + assert workspace.location == canary_location + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.resource_group == RESOURCE_GROUP + assert workspace.name == WORKSPACE + assert isinstance(workspace.credential, AzureKeyCredential) + assert workspace.credential.key == API_KEY + +def test_env_connection_string_v2(): + """Test v2 QuantumEndpoint from environment variable.""" + with mock.patch.dict(os.environ): + os.environ.clear() + os.environ[EnvironmentVariables.CONNECTION_STRING] = SIMPLE_CONNECTION_STRING_V2 + + workspace = WorkspaceMock() + assert workspace.location == LOCATION + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.name == WORKSPACE + assert workspace.resource_group == RESOURCE_GROUP + assert isinstance(workspace.credential, AzureKeyCredential) + assert workspace.credential.key == API_KEY + # pylint: disable=protected-access + assert isinstance( + workspace._client._config.authentication_policy, + AzureKeyCredentialPolicy) + auth_policy = workspace._client._config.authentication_policy + assert auth_policy._name == ConnectionConstants.QUANTUM_API_KEY_HEADER + assert id(auth_policy._credential) == id(workspace.credential) + +def test_create_workspace_instance_invalid(): + def assert_value_error(exception: Exception): + assert "Azure Quantum workspace not fully specified." in exception.args[0] + + with mock.patch.dict(os.environ): + os.environ.clear() + + # missing workspace name + try: + WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=None + ) + assert False, "Expected ValueError" + except ValueError as e: + assert_value_error(e) + + # provide only subscription id and resource group + try: + WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + ) + assert False, "Expected ValueError" + except ValueError as e: + assert_value_error(e) + + # missing everything + try: + WorkspaceMock() + assert False, "Expected ValueError" + except ValueError as e: + assert_value_error(e) + + # invalid resource id + try: + WorkspaceMock(location=LOCATION, resource_id="invalid/resource/id") + assert False, "Expected ValueError" + except ValueError as e: + assert "Invalid resource id" in e.args[0] + + +def test_workspace_user_agent_appid(): + app_id = "MyEnvVarAppId" + user_agent = "MyUserAgent" + with mock.patch.dict(os.environ): + os.environ.clear() + + # no UserAgent parameter and no EnvVar AppId + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) + assert ws.user_agent is None + + # with UserAgent parameter and no EnvVar AppId + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + user_agent=user_agent, + ) + assert ws.user_agent == user_agent + + # append with no UserAgent parameter and no EnvVar AppId + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) + ws.append_user_agent("featurex") + assert ws.user_agent == "featurex" + + # set EnvVar AppId for remaining cases + os.environ[EnvironmentVariables.USER_AGENT_APPID] = app_id + + # no UserAgent parameter and with EnvVar AppId + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) + assert ws.user_agent == app_id + + # with UserAgent parameter and EnvVar AppId + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + user_agent=user_agent, + ) + assert ws.user_agent == f"{app_id} {user_agent}" + + # append with UserAgent parameter and with EnvVar AppId + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + user_agent=user_agent, + ) + ws.append_user_agent("featurex") + assert ws.user_agent == f"{app_id} {user_agent}-featurex" + + ws.append_user_agent(None) + assert ws.user_agent == app_id + +def test_workspace_context_manager(): + """Test that Workspace can be used as a context manager""" + with WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) as ws: + # Verify workspace is properly initialized + assert ws.subscription_id == SUBSCRIPTION_ID + assert ws.resource_group == RESOURCE_GROUP + assert ws.name == WORKSPACE + assert ws.location == LOCATION + + # Verify internal clients are accessible + assert ws._client is not None + assert ws._mgmt_client is not None + +def test_workspace_context_manager_calls_enter_exit(): + """Test that __enter__ and __exit__ are called on internal clients""" + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) + + # Mock the internal clients' __enter__ and __exit__ methods + ws._client.__enter__ = mock.MagicMock(return_value=ws._client) + ws._client.__exit__ = mock.MagicMock(return_value=None) + ws._mgmt_client.__enter__ = mock.MagicMock(return_value=ws._mgmt_client) + ws._mgmt_client.__exit__ = mock.MagicMock(return_value=None) + + # Use workspace as context manager + with ws as context_ws: + # Verify __enter__ was called on both clients + ws._client.__enter__.assert_called_once() + ws._mgmt_client.__enter__.assert_called_once() + + # Verify context manager returns the workspace instance + assert context_ws is ws + + # Verify __exit__ was called on both clients after exiting context + ws._client.__exit__.assert_called_once() + ws._mgmt_client.__exit__.assert_called_once() + + +def test_get_container_uri_uses_linked_storage_sas_when_storage_none(): + """When storage is None, get_container_uri should use linked storage via service SAS.""" + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) + assert ws.storage is None + + with mock.patch( + "azure.quantum.storage.ContainerClient.from_container_url", + return_value=mock.MagicMock(), + ): + with mock.patch( + "azure.quantum.storage.create_container_using_client", + return_value=None, + ): + uri = ws.get_container_uri(job_id="job-123") + assert isinstance(uri, str) + assert "https://example.com/" in uri + assert "sas-token" in uri diff --git a/azure-quantum/tests/unit/local/test_workspace_connection_params_validation.py b/azure-quantum/tests/unit/local/test_workspace_connection_params_validation.py new file mode 100644 index 00000000..fa411473 --- /dev/null +++ b/azure-quantum/tests/unit/local/test_workspace_connection_params_validation.py @@ -0,0 +1,329 @@ +## +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +## + +import pytest +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams + + +def test_valid_subscription_ids(): + """Test that valid subscription_ids are accepted.""" + valid_ids = [ + "12345678-1234-1234-1234-123456789abc", + "ABCDEF01-2345-6789-ABCD-EF0123456789", + "abcdef01-2345-6789-abcd-ef0123456789", + ] + for subscription_id in valid_ids: + params = WorkspaceConnectionParams(subscription_id=subscription_id) + assert params.subscription_id == subscription_id + + +def test_invalid_subscription_ids(): + """Test that invalid subscription_ids raise ValueError.""" + invalid_ids = [ + ("not-a-guid", "Subscription ID must be a valid GUID."), + (12345, "Subscription ID must be a string."), + ] + for subscription_id, expected_message in invalid_ids: + with pytest.raises(ValueError) as exc_info: + WorkspaceConnectionParams(subscription_id=subscription_id) + assert expected_message in str(exc_info.value) + + +def test_valid_resource_groups(): + """Test that valid resource_groups are accepted.""" + valid_groups = [ + "my-resource-group", + "MyResourceGroup", + "resource_group_123", + "rg123", + "a" * 90, # Max length (90 chars) + "a", # Min length (1 char) + "Resource_Group-1", + "my.resource.group", # Periods allowed (except at end) + "group(test)", # Parentheses allowed + "group(test)name", + "(parentheses)", + "test-group_name", + "GROUP-123", + "123-group", + "Test.Group.Name", + "my-group.v2", + "rg_test(prod)-v1.2", + "café", # Unicode letters (Lo) + "日本語", # Unicode letters (Lo) + "Казан", # Unicode letters (Lu, Ll) + "αβγ", # Greek letters (Ll) + "test-café-123", # Mixed ASCII and Unicode + "group_名前", # Mixed ASCII and Unicode + "test.group(1)-name_v2", # Multiple special chars + ] + for resource_group in valid_groups: + params = WorkspaceConnectionParams(resource_group=resource_group) + assert params.resource_group == resource_group + + +def test_invalid_resource_groups(): + """Test that invalid resource_groups raise ValueError.""" + rg_invalid_chars_msg = "Resource group name can only include alphanumeric, underscore, parentheses, hyphen, period (except at end), and Unicode characters that match the allowed characters." + invalid_groups = [ + ("my/resource/group", rg_invalid_chars_msg), + ("my\\resource\\group", rg_invalid_chars_msg), + ("my resource group", rg_invalid_chars_msg), + (12345, "Resource group name must be a string."), + ("group.", rg_invalid_chars_msg), # Period at end + ("my-group.", rg_invalid_chars_msg), # Period at end + ("test.group.", rg_invalid_chars_msg), # Period at end + ("a" * 91, "Resource group name must be between 1 and 90 characters long."), # Too long + ("group@test", rg_invalid_chars_msg), # @ symbol + ("group#test", rg_invalid_chars_msg), # # symbol + ("group$test", rg_invalid_chars_msg), # $ symbol + ("group%test", rg_invalid_chars_msg), # % symbol + ("group^test", rg_invalid_chars_msg), # ^ symbol + ("group&test", rg_invalid_chars_msg), # & symbol + ("group*test", rg_invalid_chars_msg), # * symbol + ("group+test", rg_invalid_chars_msg), # + symbol + ("group=test", rg_invalid_chars_msg), # = symbol + ("group[test]", rg_invalid_chars_msg), # Square brackets + ("group{test}", rg_invalid_chars_msg), # Curly brackets + ("group|test", rg_invalid_chars_msg), # Pipe + ("group:test", rg_invalid_chars_msg), # Colon + ("group;test", rg_invalid_chars_msg), # Semicolon + ("group\"test", rg_invalid_chars_msg), # Quote + ("group'test", rg_invalid_chars_msg), # Single quote + ("group", rg_invalid_chars_msg), # Angle brackets + ("group,test", rg_invalid_chars_msg), # Comma + ("group?test", rg_invalid_chars_msg), # Question mark + ("group!test", rg_invalid_chars_msg), # Exclamation mark + ("group`test", rg_invalid_chars_msg), # Backtick + ("group~test", rg_invalid_chars_msg), # Tilde + ("test\ngroup", rg_invalid_chars_msg), # Newline + ("test\tgroup", rg_invalid_chars_msg), # Tab + ] + for resource_group, expected_message in invalid_groups: + with pytest.raises(ValueError) as exc_info: + WorkspaceConnectionParams(resource_group=resource_group) + assert expected_message in str(exc_info.value) + + +def test_empty_resource_group(): + """Test that empty resource_group is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(resource_group="") + assert params.resource_group is None + + +def test_valid_workspace_names(): + """Test that valid workspace names are accepted.""" + valid_names = [ + "12", + "a1", + "1a", + "ab", + "myworkspace", + "WORKSPACE", + "MyWorkspace", + "myWorkSpace", + "myworkspacE", + "1234567890", + "123workspace", + "workspace123", + "w0rksp4c3", + "123abc456def", + "abc123", + # with hyphens + "my-workspace", + "my-work-space", + "workspace-with-a-long-name-that-is-still-valid", + "a-b-c-d-e", + "my-workspace-2", + "workspace-1-2-3", + "1-a", + "b-2", + "1-2", + "a-b", + "1-b-2", + "a-1-b", + "workspace" + "-" * 10 + "test", + "a" * 54, # Max length (54 chars) + "1" * 54, # Max length with numbers + ] + for workspace_name in valid_names: + params = WorkspaceConnectionParams(workspace_name=workspace_name) + assert params.workspace_name == workspace_name + + +def test_invalid_workspace_names(): + """Test that invalid workspace names raise ValueError.""" + not_valid_names = [ + ("a", "Workspace name must be between 2 and 54 characters long."), + ("1", "Workspace name must be between 2 and 54 characters long."), + ("a" * 55, "Workspace name must be between 2 and 54 characters long."), + ("1" * 55, "Workspace name must be between 2 and 54 characters long."), + ("my_workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("my/workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("my workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("-myworkspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("myworkspace-", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + (12345, "Workspace name must be a string."), + ] + for workspace_name, expected_message in not_valid_names: + with pytest.raises(ValueError) as exc_info: + WorkspaceConnectionParams(workspace_name=workspace_name) + assert expected_message in str(exc_info.value) + + +def test_empty_workspace_name(): + """Test that empty workspace_name is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(workspace_name="") + assert params.workspace_name is None + + +def test_valid_locations(): + """Test that valid locations are accepted and normalized.""" + valid_locations = [ + ("East US", "eastus"), + ("West Europe", "westeurope"), + ("eastus", "eastus"), + ("westus2", "westus2"), + ("EASTUS", "eastus"), + ("WestUs2", "westus2"), + ("South Central US", "southcentralus"), + ("North Europe", "northeurope"), + ("Southeast Asia", "southeastasia"), + ("Japan East", "japaneast"), + ("UK South", "uksouth"), + ("Australia East", "australiaeast"), + ("Central India", "centralindia"), + ("France Central", "francecentral"), + ("Germany West Central", "germanywestcentral"), + ("Switzerland North", "switzerlandnorth"), + ("UAE North", "uaenorth"), + ("Brazil South", "brazilsouth"), + ("Korea Central", "koreacentral"), + ("South Africa North", "southafricanorth"), + ("Norway East", "norwayeast"), + ("Sweden Central", "swedencentral"), + ("Qatar Central", "qatarcentral"), + ("Poland Central", "polandcentral"), + ("Italy North", "italynorth"), + ("Israel Central", "israelcentral"), + ("Spain Central", "spaincentral"), + ("Austria East", "austriaeast"), + ("Belgium Central", "belgiumcentral"), + ("Chile Central", "chilecentral"), + ("Indonesia Central", "indonesiacentral"), + ("Malaysia West", "malaysiawest"), + ("Mexico Central", "mexicocentral"), + ("New Zealand North", "newzealandnorth"), + ("westus3", "westus3"), + ("canadacentral", "canadacentral"), + ("westcentralus", "westcentralus"), + ] + for location, expected in valid_locations: + params = WorkspaceConnectionParams(location=location) + assert params.location == expected + + +def test_invalid_locations(): + """Test that invalid locations raise ValueError.""" + location_invalid_region_msg = "Location must be one of the Azure regions listed in https://learn.microsoft.com/en-us/azure/reliability/regions-list." + invalid_locations = [ + (" ", location_invalid_region_msg), + ("invalid-region", location_invalid_region_msg), + ("us-east", location_invalid_region_msg), + ("east-us", location_invalid_region_msg), + ("westus4", location_invalid_region_msg), + ("southus", location_invalid_region_msg), + ("centraleurope", location_invalid_region_msg), + ("asiaeast", location_invalid_region_msg), + ("chinaeast", location_invalid_region_msg), + ("usgovtexas", location_invalid_region_msg), + ("East US 3", location_invalid_region_msg), + ("not a region", location_invalid_region_msg), + (12345, "Location must be a string."), + (3.14, "Location must be a string."), + (True, "Location must be a string."), + ] + for location, expected_message in invalid_locations: + with pytest.raises(ValueError) as exc_info: + WorkspaceConnectionParams(location=location) + assert expected_message in str(exc_info.value) + + +def test_empty_location(): + """Test that empty location is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(location="") + assert params.location is None + + # None is also allowed and treated as not set + params = WorkspaceConnectionParams(location=None) + assert params.location is None + + +def test_none_values_are_allowed(): + """Test that None values for optional fields are allowed.""" + # This should not raise any exceptions + params = WorkspaceConnectionParams( + subscription_id=None, + resource_group=None, + workspace_name=None, + location=None, + user_agent=None, + ) + assert params.subscription_id is None + assert params.resource_group is None + assert params.workspace_name is None + assert params.location is None + assert params.user_agent is None + + +def test_multiple_valid_parameters(): + """Test that multiple valid parameters work together.""" + params = WorkspaceConnectionParams( + subscription_id="12345678-1234-1234-1234-123456789abc", + resource_group="my-resource-group", + workspace_name="my-workspace", + location="East US", + user_agent="my-app/1.0", + ) + assert params.subscription_id == "12345678-1234-1234-1234-123456789abc" + assert params.resource_group == "my-resource-group" + assert params.workspace_name == "my-workspace" + assert params.location == "eastus" + assert params.user_agent == "my-app/1.0" + + +def test_validation_on_resource_id(): + """Test that validation works when using resource_id.""" + # Valid resource_id should work + resource_id = ( + "/subscriptions/12345678-1234-1234-1234-123456789abc" + "/resourceGroups/my-rg" + "/providers/Microsoft.Quantum" + "/Workspaces/my-ws" + ) + params = WorkspaceConnectionParams(resource_id=resource_id) + assert params.subscription_id == "12345678-1234-1234-1234-123456789abc" + assert params.resource_group == "my-rg" + assert params.workspace_name == "my-ws" + + +def test_validation_on_connection_string(): + """Test that validation works when using connection_string.""" + # Valid connection string should work + connection_string = ( + "SubscriptionId=12345678-1234-1234-1234-123456789abc;" + "ResourceGroupName=my-rg;" + "WorkspaceName=my-ws;" + "ApiKey=test-key;" + "QuantumEndpoint=https://eastus.quantum.azure.com/;" + ) + params = WorkspaceConnectionParams(connection_string=connection_string) + assert params.subscription_id == "12345678-1234-1234-1234-123456789abc" + assert params.resource_group == "my-rg" + assert params.workspace_name == "my-ws" + assert params.location == "eastus" diff --git a/azure-quantum/tests/unit/test_job_results.py b/azure-quantum/tests/unit/test_job_results.py index 9ee57341..6caf49bc 100644 --- a/azure-quantum/tests/unit/test_job_results.py +++ b/azure-quantum/tests/unit/test_job_results.py @@ -9,7 +9,6 @@ import pytest from common import QuantumTestBase, RegexScrubbingPatterns from azure.quantum import Job, JobDetails -from azure.quantum.target import Target class TestJobResults(QuantumTestBase): @@ -18,15 +17,11 @@ class TestJobResults(QuantumTestBase): Tests the azure.quantum.job module. """ - def test_job_success(self): - job_results = self._get_job_results("test_output_data_format","{\"Histogram\": [\"[0]\", 0.50, \"[1]\", 0.50]}") - self.assertTrue(len(job_results["Histogram"]) == 4) - @pytest.mark.live_test @pytest.mark.xdist_group(name="echo-output") def test_job_get_results_with_expired_sas_token(self): """ - Get existing result blob url and replace its sas token with expired one, + Get existing result blob url and replace its sas token with expired one, so we can test its ability to refresh it. """ target = self.create_echo_target() @@ -38,49 +33,66 @@ def test_job_get_results_with_expired_sas_token(self): job.details.output_data_uri = re.sub( pattern=RegexScrubbingPatterns.URL_QUERY_SAS_KEY_EXPIRATION, repl="se=2024-01-01T00%3A00%3A00Z&", - string=job.details.output_data_uri) + string=job.details.output_data_uri, + ) job_results = job.get_results() self.assertEqual(job_results, input_data) - def test_job_for_microsoft_quantum_results_v1_success(self): - job_results = self._get_job_results("microsoft.quantum-results.v1","{\"Histogram\": [\"[0]\", 0.50, \"[1]\", 0.50]}") + job_results = self._get_job_results( + "microsoft.quantum-results.v1", '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}' + ) self.assertTrue(len(job_results.keys()) == 2) self.assertEqual(job_results["[0]"], 0.50) self.assertEqual(job_results["[1]"], 0.50) - def test_job_for_microsoft_quantum_results_v1_no_histogram_returns_raw_result(self): - job_result_raw = "{\"NotHistogramProperty\": [\"[0]\", 0.50, \"[1]\", 0.50]}" - job_result = self._get_job_results("microsoft.quantum-results.v1", job_result_raw) + job_result_raw = '{"NotHistogramProperty": ["[0]", 0.50, "[1]", 0.50]}' + job_result = self._get_job_results( + "microsoft.quantum-results.v1", job_result_raw + ) self.assertEqual(job_result, job_result_raw) - - def test_job_for_microsoft_quantum_results_v1_invalid_histogram_returns_raw_result(self): - job_result_raw = "{\"NotHistogramProperty\": [\"[0]\", 0.50, \"[1]\"]}" - job_result = self._get_job_results("microsoft.quantum-results.v1", job_result_raw) + def test_job_for_microsoft_quantum_results_v1_invalid_histogram_returns_raw_result( + self, + ): + job_result_raw = '{"NotHistogramProperty": ["[0]", 0.50, "[1]"]}' + job_result = self._get_job_results( + "microsoft.quantum-results.v1", job_result_raw + ) self.assertEqual(job_result, job_result_raw) def test_job_for_microsoft_quantum_results_v2_success(self): - job_results = self._get_job_results("microsoft.quantum-results.v2","{\"DataFormat\": \"microsoft.quantum-results.v2\", \"Results\": [{\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}]}") + job_results = self._get_job_results( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) self.assertTrue(len(job_results.keys()) == 2) self.assertEqual(job_results["[0]"], 0.50) self.assertEqual(job_results["[1]"], 0.50) def test_job_for_microsoft_quantum_results_v2_wrong_type_raises_exception(self): - job_result_raw = "{\"DataFormat\": \"microsoft.quantum-results.v1\", \"Results\": [{\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}]}" - job_result = self._get_job_results("microsoft.quantum-results.v2", job_result_raw) + job_result_raw = '{"DataFormat": "microsoft.quantum-results.v1", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}' + job_result = self._get_job_results( + "microsoft.quantum-results.v2", job_result_raw + ) self.assertEqual(job_result, job_result_raw) - - def test_job_for_microsoft_quantum_results_v2_invalid_histogram_returns_raw_result(self): - job_result_raw = "{\"DataFormat\": \"microsoft.quantum-results.v2\", \"Results\": [{\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\"}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}]}" - job_result = self._get_job_results("microsoft.quantum-results.v2", job_result_raw) + def test_job_for_microsoft_quantum_results_v2_invalid_histogram_returns_raw_result( + self, + ): + job_result_raw = '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]"}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}' + job_result = self._get_job_results( + "microsoft.quantum-results.v2", job_result_raw + ) self.assertEqual(job_result, job_result_raw) def test_job_for_microsoft_quantum_results_histogram_v2_success(self): - job_results = self._get_job_results_histogram("microsoft.quantum-results.v2","{\"DataFormat\": \"microsoft.quantum-results.v2\", \"Results\": [{\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}]}") + job_results = self._get_job_results_histogram( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) self.assertTrue(len(job_results.keys()) == 2) self.assertEqual(job_results["[0]"]["count"], 2) self.assertEqual(job_results["[1]"]["count"], 2) @@ -88,7 +100,10 @@ def test_job_for_microsoft_quantum_results_histogram_v2_success(self): self.assertEqual(job_results["[1]"]["outcome"], [1]) def test_job_for_microsoft_quantum_results_histogram_batch_v2_success(self): - job_results = self._get_job_results_histogram("microsoft.quantum-results.v2","{\"DataFormat\": \"microsoft.quantum-results.v2\", \"Results\": [{\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}, {\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}, {\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}]}") + job_results = self._get_job_results_histogram( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}, {"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}, {"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) self.assertTrue(len(job_results) == 3) for result in job_results: self.assertTrue(len(result.keys()) == 2) @@ -97,24 +112,35 @@ def test_job_for_microsoft_quantum_results_histogram_batch_v2_success(self): self.assertEqual(result["[0]"]["outcome"], [0]) self.assertEqual(result["[1]"]["outcome"], [1]) - def test_job_for_microsoft_quantum_results_histogram_v2_wrong_type_raises_exception(self): + def test_job_for_microsoft_quantum_results_histogram_v2_wrong_type_raises_exception( + self, + ): try: - job_results = self._get_job_results_histogram("microsoft.quantum-results.v2","{\"Histogram\": [\"[0]\", 0.50, \"[1]\", 0.50]}") + job_results = self._get_job_results_histogram( + "microsoft.quantum-results.v2", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) # Fail test because we didn't get the error self.assertTrue(False) except: self.assertTrue(True) def test_job_for_microsoft_quantum_results_shots_v2_success(self): - job_results = self._get_job_results_shots("microsoft.quantum-results.v2","{\"DataFormat\": \"microsoft.quantum-results.v2\", \"Results\": [{\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}]}") + job_results = self._get_job_results_shots( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) self.assertTrue(len(job_results) == 4) self.assertEqual(job_results[0], [0]) self.assertEqual(job_results[1], [1]) self.assertEqual(job_results[2], [1]) self.assertEqual(job_results[3], [0]) - + def test_job_for_microsoft_quantum_results_shots_batch_v2_success(self): - job_results = self._get_job_results_shots("microsoft.quantum-results.v2","{\"DataFormat\": \"microsoft.quantum-results.v2\", \"Results\": [{\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}, {\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}, {\"Histogram\": [{\"Outcome\": [0], \"Display\": \"[0]\", \"Count\": 2}, {\"Outcome\": [1], \"Display\": \"[1]\", \"Count\": 2}], \"Shots\": [[0], [1], [1], [0]]}]}") + job_results = self._get_job_results_shots( + "microsoft.quantum-results.v2", + '{"DataFormat": "microsoft.quantum-results.v2", "Results": [{"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}, {"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}, {"Histogram": [{"Outcome": [0], "Display": "[0]", "Count": 2}, {"Outcome": [1], "Display": "[1]", "Count": 2}], "Shots": [[0], [1], [1], [0]]}]}', + ) self.assertTrue(len(job_results) == 3) for i in range(3): self.assertTrue(len(job_results[i]) == 4) @@ -124,7 +150,7 @@ def test_job_for_microsoft_quantum_results_shots_batch_v2_success(self): self.assertEqual(job_results[i][3], [0]) def test_job_for_microsoft_quantum_results_histogram_v2_tuple_success(self): - output = '''{ + output = """{ \"DataFormat\": \"microsoft.quantum-results.v2\", \"Results\": [ { @@ -184,19 +210,26 @@ def test_job_for_microsoft_quantum_results_histogram_v2_tuple_success(self): ] } ] -}''' - job_results = self._get_job_results_histogram("microsoft.quantum-results.v2", output) - +}""" + job_results = self._get_job_results_histogram( + "microsoft.quantum-results.v2", output + ) + self.assertTrue(len(job_results.keys()) == 3) self.assertEqual(job_results["[1, 0]"]["count"], 1) self.assertEqual(job_results["[1]"]["count"], 1) - self.assertEqual(job_results["([1, 0], (-2.71, 67), [(6, true), (12, false)])"]["count"], 1) - self.assertEqual(job_results["([1, 0], (-2.71, 67), [(6, true), (12, false)])"]["outcome"], ([1, 0], (-2.71, 67), [(6, True), (12, False)])) + self.assertEqual( + job_results["([1, 0], (-2.71, 67), [(6, true), (12, false)])"]["count"], 1 + ) + self.assertEqual( + job_results["([1, 0], (-2.71, 67), [(6, true), (12, false)])"]["outcome"], + ([1, 0], (-2.71, 67), [(6, True), (12, False)]), + ) self.assertEqual(job_results["[1]"]["outcome"], [1]) self.assertEqual(job_results["[1, 0]"]["outcome"], [1, 0]) def test_job_for_microsoft_quantum_results_shots_v2_tuple_success(self): - output = '''{ + output = """{ \"DataFormat\": \"microsoft.quantum-results.v2\", \"Results\": [ { @@ -242,17 +275,24 @@ def test_job_for_microsoft_quantum_results_shots_v2_tuple_success(self): ] } ] - }''' - job_results = self._get_job_results_shots("microsoft.quantum-results.v2", output) + }""" + job_results = self._get_job_results_shots( + "microsoft.quantum-results.v2", output + ) self.assertTrue(len(job_results) == 3) self.assertEqual(job_results[0], ([1, 0], (-2.71, 67))) self.assertEqual(job_results[1], [1, 0]) self.assertEqual(job_results[2], [1]) - def test_job_for_microsoft_quantum_results_shots_v2_wrong_type_raises_exception(self): + def test_job_for_microsoft_quantum_results_shots_v2_wrong_type_raises_exception( + self, + ): try: - job_results = self._get_job_results_shots("microsoft.quantum-results.v2","{\"Histogram\": [\"[0]\", 0.50, \"[1]\", 0.50]}") + job_results = self._get_job_results_shots( + "microsoft.quantum-results.v2", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) # Fail test because we didn't get the error self.assertTrue(False) except: @@ -260,14 +300,22 @@ def test_job_for_microsoft_quantum_results_shots_v2_wrong_type_raises_exception( def test_job_get_results_with_succeeded_status(self): """Test that get_results works correctly when job status is 'Succeeded'""" - job_results = self._get_job_results_with_status("Succeeded", "microsoft.quantum-results.v1", "{\"Histogram\": [\"[0]\", 0.50, \"[1]\", 0.50]}") + job_results = self._get_job_results_with_status( + "Succeeded", + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) self.assertTrue(len(job_results.keys()) == 2) self.assertEqual(job_results["[0]"], 0.50) self.assertEqual(job_results["[1]"], 0.50) def test_job_get_results_with_completed_status(self): """Test that get_results works correctly when job status is 'Completed'""" - job_results = self._get_job_results_with_status("Completed", "microsoft.quantum-results.v1", "{\"Histogram\": [\"[0]\", 0.50, \"[1]\", 0.50]}") + job_results = self._get_job_results_with_status( + "Completed", + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) self.assertTrue(len(job_results.keys()) == 2) self.assertEqual(job_results["[0]"], 0.50) self.assertEqual(job_results["[1]"], 0.50) @@ -275,57 +323,72 @@ def test_job_get_results_with_completed_status(self): def test_job_get_results_with_failed_status_raises_runtime_error(self): """Test that get_results raises RuntimeError when job status is 'Failed'""" with self.assertRaises(RuntimeError) as context: - self._get_job_results_with_status("Failed", "microsoft.quantum-results.v1", "{\"Histogram\": [\"[0]\", 0.50, \"[1]\", 0.50]}") - self.assertIn("Cannot retrieve results as job execution failed", str(context.exception)) + self._get_job_results_with_status( + "Failed", + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) + self.assertIn( + "Cannot retrieve results as job execution failed", str(context.exception) + ) self.assertIn("FAILED", str(context.exception)) def test_job_get_results_with_cancelled_status_raises_runtime_error(self): """Test that get_results raises RuntimeError when job status is 'Cancelled'""" with self.assertRaises(RuntimeError) as context: - self._get_job_results_with_status("Cancelled", "microsoft.quantum-results.v1", "{\"Histogram\": [\"[0]\", 0.50, \"[1]\", 0.50]}") - self.assertIn("Cannot retrieve results as job execution failed", str(context.exception)) + self._get_job_results_with_status( + "Cancelled", + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + ) + self.assertIn( + "Cannot retrieve results as job execution failed", str(context.exception) + ) self.assertIn("CANCELLED", str(context.exception)) def _get_job_results(self, output_data_format, results_as_json_str): job = self._mock_job(output_data_format, results_as_json_str) - + return job.get_results() - - def _get_job_results_with_status(self, status, output_data_format, results_as_json_str): + + def _get_job_results_with_status( + self, status, output_data_format, results_as_json_str + ): job = self._mock_job(output_data_format, results_as_json_str) job.details.status = status - + return job.get_results() - + def _get_job_results_histogram(self, output_data_format, results_as_json_str): job = self._mock_job(output_data_format, results_as_json_str) return job.get_results_histogram() - + def _get_job_results_shots(self, output_data_format, results_as_json_str): job = self._mock_job(output_data_format, results_as_json_str) - + return job.get_results_shots() - + def _mock_job(self, output_data_format, results_as_json_str): job_details = JobDetails( - id= "", - name= "", + id="", + name="", provider_id="", target="", container_uri="", input_data_format="", - output_data_format = output_data_format) + output_data_format=output_data_format, + ) job_details.status = "Succeeded" - job = Job( - workspace=None, - job_details=job_details) - + job = Job(workspace=None, job_details=job_details) + job.has_completed = Mock(return_value=True) job.wait_until_completed = Mock() class DowloadDataMock(object): - def decode(): str + def decode(): + str + pass download_data = DowloadDataMock()