Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 56 additions & 20 deletions sdks/python/src/opik/integrations/langchain/opik_tracer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import logging
import datetime
from typing import Any, Dict, List, Literal, Optional, Set, TYPE_CHECKING, cast, Tuple
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Set,
TYPE_CHECKING,
cast,
Tuple,
Callable,
)
import contextvars
from uuid import UUID

Expand Down Expand Up @@ -36,6 +47,11 @@

SpanType = Literal["llm", "tool", "general"]

SkipErrorCallback = Callable[[str], bool]


ERROR_SKIPPED_OUTPUTS = {"warning": "Error output skipped by skip_error_callback."}


def _get_span_type(run: Dict[str, Any]) -> SpanType:
if run.get("run_type") in ["llm", "tool"]:
Expand Down Expand Up @@ -65,6 +81,7 @@ def __init__(
project_name: Optional[str] = None,
distributed_headers: Optional[DistributedTraceHeadersDict] = None,
thread_id: Optional[str] = None,
skip_error_callback: Optional[SkipErrorCallback] = None,
**kwargs: Any,
) -> None:
validator = parameters_validator.create_validator(
Expand Down Expand Up @@ -113,6 +130,8 @@ def __init__(
Optional[str]
] = contextvars.ContextVar("root_run_external_parent_span_id", default=None)

self._skip_error_callback = skip_error_callback

def _is_opik_span_created_by_this_tracer(self, span_id: str) -> bool:
return any(span_.id == span_id for span_ in self._span_data_map.values())

Expand All @@ -135,17 +154,23 @@ def _persist_run(self, run: Run) -> None:
error_info: Optional[ErrorInfoDict]
trace_additional_metadata: Dict[str, Any] = {}

if run_dict["error"] is not None:
output = None
error_info = ErrorInfoDict(
exception_type="Exception",
traceback=run_dict["error"],
)
else:
output, trace_additional_metadata = (
langchain_helpers.split_big_langgraph_outputs(run_dict["outputs"])
error_str = run_dict.get("error")
outputs = run_dict.get("outputs")
error_info = None

if error_str is not None:
outputs = None
if not self._should_skip_error(error_str):
error_info = ErrorInfoDict(
exception_type="Exception",
traceback=error_str,
)
else:
outputs = ERROR_SKIPPED_OUTPUTS
elif outputs is not None:
outputs, trace_additional_metadata = (
langchain_helpers.split_big_langgraph_outputs(outputs)
)
error_info = None

if (
span_data.parent_span_id is not None
Expand All @@ -169,7 +194,7 @@ def _persist_run(self, run: Run) -> None:
if trace_additional_metadata:
trace_data.update(metadata=trace_additional_metadata)

trace_data.init_end_time().update(output=output, error_info=error_info)
trace_data.init_end_time().update(output=outputs, error_info=error_info)
trace_ = self._opik_client.trace(**trace_data.as_parameters)

assert trace_ is not None
Expand Down Expand Up @@ -446,6 +471,12 @@ def _process_end_span(self, run: Run) -> None:
)
self._opik_context_storage.pop_span_data(ensure_id=span_data.id)

def _should_skip_error(self, error_str: str) -> bool:
if self._skip_error_callback is None:
return False

return self._skip_error_callback(error_str)

def _process_end_span_with_error(self, run: Run) -> None:
if run.id not in self._span_data_map:
LOGGER.warning(
Expand All @@ -457,15 +488,20 @@ def _process_end_span_with_error(self, run: Run) -> None:
try:
run_dict: Dict[str, Any] = run.dict()
span_data = self._span_data_map[run.id]
error_info: ErrorInfoDict = {
"exception_type": "Exception",
"traceback": run_dict["error"],
}
error_str = run_dict["error"]

if self._should_skip_error(error_str):
span_data.init_end_time().update(output=ERROR_SKIPPED_OUTPUTS)
else:
error_info = ErrorInfoDict(
exception_type="Exception",
traceback=error_str,
)
span_data.init_end_time().update(
output=None,
error_info=error_info,
)

span_data.init_end_time().update(
output=None,
error_info=error_info,
)
if tracing_runtime_config.is_tracing_active():
self._opik_client.span(**span_data.as_parameters)
except Exception as e:
Expand Down
100 changes: 99 additions & 1 deletion sdks/python/tests/library_integration/langchain/test_langchain.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import pytest
from langchain_core.language_models import fake
from langchain_core.language_models.fake import FakeStreamingListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig

import opik
from opik import context_storage
from opik.api_objects import opik_client, span, trace
from opik.config import OPIK_PROJECT_DEFAULT_NAME
from opik.integrations.langchain.opik_tracer import OpikTracer
from opik.integrations.langchain.opik_tracer import OpikTracer, ERROR_SKIPPED_OUTPUTS
from opik.types import DistributedTraceHeadersDict

from ...testlib import (
ANY_BUT_NONE,
ANY_DICT,
Expand Down Expand Up @@ -653,3 +656,98 @@ def test_langchain_callback__disabled_tracking(fake_backend):

assert len(fake_backend.trace_trees) == 0
assert len(callback.created_traces()) == 0


def test_langchain_callback__skip_error_callback__error_output_skipped(
    fake_backend,
):
    """When skip_error_callback matches the raised error, the trace and the
    failing span record ERROR_SKIPPED_OUTPUTS instead of error info."""

    def _skip_fake_llm_errors(error: str) -> bool:
        # Errors raised by the fake LLM are expected here - safe to skip.
        return error is not None and error.startswith("FakeListLLMError")

    tracer = OpikTracer(
        skip_error_callback=_skip_fake_llm_errors,
    )

    llm = FakeStreamingListLLM(
        error_on_chunk_number=0,  # throw error on the first chunk
        responses=["I'm sorry, I don't think I'm talented enough to write a synopsis"],
    )

    template = "Given the title of play, write a synopsys for that. Title: {title}."
    prompt_template = PromptTemplate(input_variables=["title"], template=template)

    synopsis_chain = prompt_template | llm
    test_prompts = {"title": "Documentary about Bigfoot in Paris"}

    try:
        for chunk in synopsis_chain.stream(
            input=test_prompts, config=RunnableConfig(callbacks=[tracer])
        ):
            print(chunk)
    except Exception:
        # The streaming error is deliberately swallowed; only the
        # recorded trace matters for this test.
        pass

    opik.flush_tracker()

    assert len(fake_backend.trace_trees) == 1

    expected_trace_tree = TraceModel(
        id=ANY_BUT_NONE,
        start_time=ANY_BUT_NONE,
        name="RunnableSequence",
        project_name="Default Project",
        input={"title": "Documentary about Bigfoot in Paris"},
        output=ERROR_SKIPPED_OUTPUTS,
        metadata={"created_from": "langchain"},
        end_time=ANY_BUT_NONE,
        spans=[
            SpanModel(
                id=ANY_BUT_NONE,
                start_time=ANY_BUT_NONE,
                name="RunnableSequence",
                input={"input": ""},
                output=ERROR_SKIPPED_OUTPUTS,
                metadata={"created_from": "langchain"},
                type="general",
                end_time=ANY_BUT_NONE,
                project_name="Default Project",
                spans=[
                    SpanModel(
                        id=ANY_BUT_NONE,
                        start_time=ANY_BUT_NONE,
                        name="PromptTemplate",
                        input={"title": "Documentary about Bigfoot in Paris"},
                        output={"output": ANY_DICT},
                        metadata={"created_from": "langchain"},
                        type="tool",
                        end_time=ANY_BUT_NONE,
                        project_name="Default Project",
                        last_updated_at=ANY_BUT_NONE,
                    ),
                    SpanModel(
                        id=ANY_BUT_NONE,
                        start_time=ANY_BUT_NONE,
                        name="FakeStreamingListLLM",
                        input={"prompts": ANY_BUT_NONE},
                        output=ANY_DICT,
                        tags=None,
                        metadata=ANY_DICT,
                        type="llm",
                        end_time=ANY_BUT_NONE,
                        project_name="Default Project",
                        last_updated_at=ANY_BUT_NONE,
                    ),
                ],
                last_updated_at=ANY_BUT_NONE,
            )
        ],
        last_updated_at=ANY_BUT_NONE,
    )

    assert_equal(expected=expected_trace_tree, actual=fake_backend.trace_trees[0])
Loading