diff --git a/jax/_src/core.py b/jax/_src/core.py index 88c3c402250d..c9ed93d63fda 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3712,6 +3712,14 @@ def pp_vars(vs: Sequence[Atom], context: JaxprPpContext, [pp.text(pp_var(v, context)) for v in vs]) )) +def _pp_set(v: set | frozenset) -> pp.Doc: + # sorted for deterministic output in jaxpr + if not v: + return pp.text(str(v)) + sorted_reprs = [repr(x) for x in sorted(v, key=str)] + type_name = type(v).__name__ + return pp.text(f"{type_name}({{{', '.join(sorted_reprs)}}})") + def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc: if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v): pp_v = pp_jaxprs(v, context, settings) @@ -3719,6 +3727,8 @@ def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, settings: JaxprPpSettings pp_v = pp_jaxpr(v, context, settings) elif isinstance(v, ClosedJaxpr): pp_v = pp_jaxpr(v.jaxpr, context, settings) + elif isinstance(v, (set, frozenset)): + pp_v = _pp_set(v) else: pp_v = pp.text(str(v)) return pp.text(f'{k}=') + pp_v