diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index a51de93d6..8538e02ed 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -1,3 +1,4 @@ +import contextvars import logging import os from collections.abc import Iterator @@ -18,6 +19,7 @@ UiPathRuntimeStorageProtocol, UiPathStreamOptions, ) +from uipath.tracing import ReferenceContext, ReferenceContextAccessor from uipath.runtime.errors import ( UiPathBaseRuntimeError, UiPathErrorCategory, @@ -75,12 +77,31 @@ def __init__( self.chat.client_side_tools = self._get_client_side_tools() self._middleware_node_names: set[str] = self._detect_middleware_nodes() + def _push_reference_context(self) -> contextvars.Token: + """Append this runtime's own entry to the ambient ReferenceContext. + + Reads any parent context already in the accessor (e.g. set by an + upstream middleware or the agents-python runtime), then appends a + ``langgraph`` entry for this runtime. Returns the ContextVar token + so the caller can reset in a ``finally`` block. + """ + agent_id = os.environ.get("UIPATH_AGENT_ID") + agent_version = os.environ.get("UIPATH_PROCESS_VERSION") or None + parent_ctx = ReferenceContextAccessor.get() or ReferenceContext.Empty + ref_ctx = ( + parent_ctx.add("langgraph", agent_id, agent_version) + if agent_id + else parent_ctx + ) + return ReferenceContextAccessor.set(ref_ctx) + async def execute( self, input: dict[str, Any] | None = None, options: UiPathExecuteOptions | None = None, ) -> UiPathRuntimeResult: """Execute the graph with the provided input and configuration.""" + ref_ctx_token = self._push_reference_context() try: graph_input = await self._get_graph_input(input, options) graph_config = self._get_graph_config() @@ -99,6 +120,8 @@ async def execute( except Exception as e: raise self.create_runtime_error(e) from e + finally: + ReferenceContextAccessor.reset(ref_ctx_token) async def stream( self, @@ -133,6 +156,7 @@ async def stream( Raises: LangGraphRuntimeError: If execution fails """ + ref_ctx_token = self._push_reference_context() try: graph_input = await self._get_graph_input(input, options) graph_config = self._get_graph_config() @@ -230,6 +254,8 @@ async def stream( except Exception as e: raise self.create_runtime_error(e) from e + finally: + ReferenceContextAccessor.reset(ref_ctx_token) async def get_schema(self) -> UiPathRuntimeSchema: """Get schema for this LangGraph runtime.""" diff --git a/tests/runtime/test_reference_context_wiring.py b/tests/runtime/test_reference_context_wiring.py new file mode 100644 index 000000000..e944ca625 --- /dev/null +++ b/tests/runtime/test_reference_context_wiring.py @@ -0,0 +1,205 @@ +"""Tests for ReferenceContext wiring in UiPathLangGraphRuntime.""" + +import os +import tempfile +from typing import Any, TypedDict + +import pytest +from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver +from langgraph.graph import END, START, StateGraph + +from uipath.platform.common._reference_context import ( + ReferenceContext, + ReferenceContextAccessor, +) +from uipath_langchain.runtime.runtime import UiPathLangGraphRuntime + + +# --------------------------------------------------------------------------- +# Minimal graph fixture +# --------------------------------------------------------------------------- + +class _State(TypedDict): + value: str + + +def _build_graph() -> Any: + graph = StateGraph(_State) + graph.add_node("step", lambda s: {"value": s.get("value", "") + "_done"}) + graph.add_edge(START, "step") + graph.add_edge("step", END) + return graph + + +def _clear_accessor() -> None: + token = ReferenceContextAccessor.set(None) + ReferenceContextAccessor.reset(token) + + +# --------------------------------------------------------------------------- +# _push_reference_context — unit tests (no graph needed) +# --------------------------------------------------------------------------- + +class TestPushReferenceContext: + def setup_method(self) -> None: + _clear_accessor() + + def teardown_method(self) -> None: + _clear_accessor() + + def test_sets_langgraph_entry_when_agent_id_present( + self, monkeypatch: pytest.MonkeyPatch, tmp_path + ) -> None: + monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020") + monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False) + + from langgraph.graph import StateGraph + graph = _build_graph().compile() + runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="t") + + token = runtime._push_reference_context() + try: + ctx = ReferenceContextAccessor.get() + assert ctx is not None + assert len(ctx) == 1 + assert ctx.entries[0].service_type == "langgraph" + assert ctx.entries[0].reference_id == "550e8400-e29b-41d4-a716-446655440020" + assert ctx.entries[0].version is None + finally: + ReferenceContextAccessor.reset(token) + + def test_includes_version_when_env_set( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020") + monkeypatch.setenv("UIPATH_PROCESS_VERSION", "3.1.0") + + graph = _build_graph().compile() + runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="t") + + token = runtime._push_reference_context() + try: + ctx = ReferenceContextAccessor.get() + assert ctx is not None + assert ctx.entries[0].version == "3.1.0" + finally: + ReferenceContextAccessor.reset(token) + + def test_no_entry_when_agent_id_absent( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("UIPATH_AGENT_ID", raising=False) + monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False) + + graph = _build_graph().compile() + runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="t") + + token = runtime._push_reference_context() + try: + ctx = ReferenceContextAccessor.get() + assert ctx is not None + assert len(ctx) == 0 + finally: + ReferenceContextAccessor.reset(token) + + def test_stacks_on_top_of_parent_context( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020") + monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False) + + parent = ReferenceContext.Empty.add( + "agent", "550e8400-e29b-41d4-a716-446655440001", "1.0" + ) + parent_token = ReferenceContextAccessor.set(parent) + + graph = _build_graph().compile() + runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="t") + + token = runtime._push_reference_context() + try: + ctx = ReferenceContextAccessor.get() + assert ctx is not None + assert len(ctx) == 2 + assert ctx.entries[0].service_type == "agent" + assert ctx.entries[1].service_type == "langgraph" + finally: + ReferenceContextAccessor.reset(token) + ReferenceContextAccessor.reset(parent_token) + + +# --------------------------------------------------------------------------- +# execute() — context cleared after run +# --------------------------------------------------------------------------- + +async def test_context_cleared_after_execute( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + _clear_accessor() + monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020") + monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False) + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db = f.name + + async with AsyncSqliteSaver.from_conn_string(db) as memory: + await memory.setup() + graph = _build_graph().compile(checkpointer=memory) + runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="exec-run") + await runtime.execute(input={"value": "hello"}) + + assert ReferenceContextAccessor.get() is None + + +async def test_context_cleared_after_execute_on_error( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + _clear_accessor() + monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020") + + class _S(TypedDict): + v: str + + def _boom(s: _S) -> _S: + raise ValueError("explode") + + g = StateGraph(_S) + g.add_node("boom", _boom) + g.add_edge(START, "boom") + g.add_edge("boom", END) + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db = f.name + + async with AsyncSqliteSaver.from_conn_string(db) as memory: + await memory.setup() + compiled = g.compile(checkpointer=memory) + runtime = UiPathLangGraphRuntime(graph=compiled, runtime_id="err-run") + with pytest.raises(Exception): + await runtime.execute(input={"v": "x"}) + + assert ReferenceContextAccessor.get() is None + + +# --------------------------------------------------------------------------- +# stream() — context cleared after run +# --------------------------------------------------------------------------- + +async def test_context_cleared_after_stream( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + _clear_accessor() + monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020") + monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False) + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db = f.name + + async with AsyncSqliteSaver.from_conn_string(db) as memory: + await memory.setup() + graph = _build_graph().compile(checkpointer=memory) + runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="stream-run") + async for _ in runtime.stream(input={"value": "hi"}): + pass + + assert ReferenceContextAccessor.get() is None