Skip to content

Commit

Permalink
Better wrap traceable (langchain-ai#12303)
Browse files Browse the repository at this point in the history
If the user's function is wrapped as a traceable function, this will help hand
off the trace between the two.

Also update handling of fields to reflect their optional values.
  • Loading branch information
hinthornw authored Oct 25, 2023
1 parent 5a71b81 commit 1d568e1
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 47 deletions.
6 changes: 4 additions & 2 deletions libs/langchain/langchain/callbacks/tracers/log_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def _on_run_create(self, run: Run) -> None:
name=run.name,
type=run.run_type,
tags=run.tags or [],
metadata=run.extra.get("metadata", {}),
metadata=(run.extra or {}).get("metadata", {}),
start_time=run.start_time.isoformat(timespec="milliseconds"),
streamed_output_str=[],
final_output=None,
Expand Down Expand Up @@ -266,7 +266,9 @@ def _on_run_update(self, run: Run) -> None:
{
"op": "add",
"path": f"/logs/{index}/end_time",
"value": run.end_time.isoformat(timespec="milliseconds"),
"value": run.end_time.isoformat(timespec="milliseconds")
if run.end_time is not None
else None,
},
)
)
Expand Down
7 changes: 5 additions & 2 deletions libs/langchain/langchain/callbacks/tracers/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import datetime
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type
from uuid import UUID

from langsmith.schemas import RunBase as BaseRunV2
Expand All @@ -13,7 +13,7 @@
from langchain.schema import LLMResult


def RunTypeEnum() -> RunTypeEnumDep:
def RunTypeEnum() -> Type[RunTypeEnumDep]:
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"
Expand Down Expand Up @@ -106,6 +106,7 @@ class Run(BaseRunV2):
child_execution_order: int
child_runs: List[Run] = Field(default_factory=list)
tags: Optional[List[str]] = Field(default_factory=list)
events: List[Dict[str, Any]] = Field(default_factory=list)

@root_validator(pre=True)
def assign_name(cls, values: dict) -> dict:
Expand All @@ -115,6 +116,8 @@ def assign_name(cls, values: dict) -> dict:
values["name"] = values["serialized"]["name"]
elif "id" in values["serialized"]:
values["name"] = values["serialized"]["id"][-1]
if values.get("events") is None:
values["events"] = []
return values


Expand Down
10 changes: 7 additions & 3 deletions libs/langchain/langchain/callbacks/tracers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
PRINT_WARNINGS = True


def _serialize_io(run_inputs: dict) -> dict:
def _serialize_io(run_inputs: Optional[dict]) -> dict:
if not run_inputs:
return {}
from google.protobuf.json_format import MessageToJson
from google.protobuf.message import Message

Expand Down Expand Up @@ -79,7 +81,9 @@ def _convert_run_to_wb_span(self, run: Run) -> "Span":
span_id=str(run.id) if run.id is not None else None,
name=run.name,
start_time_ms=int(run.start_time.timestamp() * 1000),
end_time_ms=int(run.end_time.timestamp() * 1000),
end_time_ms=int(run.end_time.timestamp() * 1000)
if run.end_time is not None
else None,
status_code=self.trace_tree.StatusCode.SUCCESS
if run.error is None
else self.trace_tree.StatusCode.ERROR,
Expand All @@ -95,7 +99,7 @@ def _convert_llm_run_to_wb_span(self, run: Run) -> "Span":
base_span = self._convert_run_to_wb_span(run)
if base_span.attributes is None:
base_span.attributes = {}
base_span.attributes["llm_output"] = run.outputs.get("llm_output", {})
base_span.attributes["llm_output"] = (run.outputs or {}).get("llm_output", {})

base_span.results = [
self.trace_tree.Result(
Expand Down
8 changes: 5 additions & 3 deletions libs/langchain/langchain/chat_loaders/langsmith.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union, cast

from langchain.chat_loaders.base import BaseChatLoader
from langchain.load import load
Expand Down Expand Up @@ -79,7 +79,7 @@ def _get_functions_from_llm_run(llm_run: "Run") -> Optional[List[Dict]]:
"""
if llm_run.run_type != "llm":
raise ValueError(f"Expected run of type llm. Got: {llm_run.run_type}")
return llm_run.extra.get("invocation_params", {}).get("functions")
return (llm_run.extra or {}).get("invocation_params", {}).get("functions")

def lazy_load(self) -> Iterator[ChatSession]:
"""
Expand All @@ -90,13 +90,15 @@ def lazy_load(self) -> Iterator[ChatSession]:
:return: Iterator of chat sessions containing messages.
"""
from langsmith.schemas import Run

for run_obj in self.runs:
try:
if hasattr(run_obj, "id"):
run = run_obj
else:
run = self.client.read_run(run_obj)
session = self._load_single_chat_session(run)
session = self._load_single_chat_session(cast(Run, run))
yield session
except ValueError as e:
logger.warning(f"Could not load run {run_obj}: {repr(e)}")
Expand Down
15 changes: 12 additions & 3 deletions libs/langchain/langchain/smith/evaluation/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
cast,
)

from langsmith import Client, RunEvaluator
from langsmith.client import Client
from langsmith.evaluation import RunEvaluator
from langsmith.run_helpers import as_runnable, is_traceable_function
from langsmith.schemas import Dataset, DataType, Example

from langchain._api import warn_deprecated
Expand Down Expand Up @@ -152,6 +154,9 @@ def _wrap_in_chain_factory(
lcf = llm_or_chain_factory
return lambda: lcf
elif callable(llm_or_chain_factory):
if is_traceable_function(llm_or_chain_factory):
runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
return lambda: runnable_
try:
_model = llm_or_chain_factory() # type: ignore[call-arg]
except TypeError:
Expand All @@ -166,6 +171,9 @@ def _wrap_in_chain_factory(
# It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user.
return _model
elif is_traceable_function(cast(Callable, _model)):
runnable_ = as_runnable(cast(Callable, _model))
return lambda: runnable_
elif not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function
return lambda: RunnableLambda(constructor)
Expand Down Expand Up @@ -879,7 +887,8 @@ def _prepare_eval_run(
f"Project {project_name} already exists. Please use a different name."
)
print(
f"View the evaluation results for project '{project_name}' at:\n{project.url}",
f"View the evaluation results for project '{project_name}'"
f" at:\n{project.url}?eval=true",
flush=True,
)
examples = list(client.list_examples(dataset_id=dataset.id))
Expand Down Expand Up @@ -909,7 +918,7 @@ def _prepare_run_on_dataset(
)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
run_evaluators = _setup_evaluation(
wrapped_model, examples, evaluation, dataset.data_type
wrapped_model, examples, evaluation, dataset.data_type or DataType.kv
)
_validate_example_inputs(examples[0], wrapped_model, input_mapper)
progress_bar = progress.ProgressBarCallback(len(examples))
Expand Down
34 changes: 4 additions & 30 deletions libs/langchain/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/langchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ cassio = {version = "^0.1.0", optional = true}
rdflib = {version = "^6.3.2", optional = true}
sympy = {version = "^1.12", optional = true}
rapidfuzz = {version = "^3.1.1", optional = true}
langsmith = "~0.0.43"
langsmith = "~0.0.52"
rank-bm25 = {version = "^0.2.2", optional = true}
amadeus = {version = ">=8.1.0", optional = true}
geopandas = {version = "^0.13.1", optional = true}
Expand Down
Loading

0 comments on commit 1d568e1

Please sign in to comment.