Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
0d9923f
adds diff test code
kamiradi Feb 9, 2026
5214e1c
demo code that visualises translation and its derivative
kamiradi Feb 10, 2026
dc7fbc1
[untested] basic gradient broadcasting via query
kamiradi Feb 11, 2026
0602f55
basic testing of passing gradients over messages
kamiradi Feb 12, 2026
06159f0
queryable not receiving queries
kamiradi Feb 12, 2026
5d05ab3
simple gradient experiment
kamiradi Feb 13, 2026
0537c41
gradient querying working with ark querier and queriable
kamiradi Feb 13, 2026
d76495e
adds readme to run gradient experiment
kamiradi Feb 13, 2026
4317c11
adds automatic gradient query handling, separates parameter nodes and…
kamiradi Feb 17, 2026
fa41a24
formatting
kamiradi Feb 17, 2026
8293fbe
removes unused files
kamiradi Feb 17, 2026
7d9c550
creates Variable class to maintain values and gradients, modifies dem…
kamiradi Feb 17, 2026
eb7165b
Variable class, modified demo, formatting
kamiradi Feb 17, 2026
54aa0b0
modifies Variable class to handle computation of gradients on query
kamiradi Feb 17, 2026
d51880b
moves queryables into Variable class
kamiradi Feb 18, 2026
5c1e5a2
reorganisation, placed Variable within variable.py
kamiradi Feb 18, 2026
cb10a47
Registers gradient query channel data on registry. Basic query discov…
kamiradi Feb 18, 2026
04f6095
impose temporal correlation between variable and gradients
kamiradi Feb 19, 2026
e5334e7
adds functionality to query gradient at a time step
kamiradi Feb 19, 2026
9c04fa8
infers leaf node in dcg, autosubscribes, backward call revealed to user
kamiradi Feb 23, 2026
96a7518
adds simple pybullet demo and gain optimizer node
kamiradi Feb 24, 2026
481501d
tests PD tuning via distributed CG
kamiradi Feb 25, 2026
22c5586
code readability
kamiradi Feb 25, 2026
c5944d4
kp gains tuning demo
kamiradi Feb 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/ark/comm/queriable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,37 @@ def __init__(
):
super().__init__(node_name, session, clock, channel, data_collector)
self._handler = handler
self._queryable = self._session.declare_queryable(self._channel, self._on_query)
self._queryable = self._session.declare_queryable(self._channel,
self._on_query,
complete=False)
print(f"Declared queryable on channel: {self._channel}")

def core_registration(self):
print("..todo: register with ark core..")

def _on_query(self, query: zenoh.Query) -> None:
# If we were closed, ignore queries
if not self._active:
print("Received query on closed Queryable, ignoring")
return

try:
# Zenoh query may or may not include a payload.
# For your use-case, the request is always in query.value (bytes)
raw = bytes(query.value) if query.value is not None else b""
raw = bytes(query.payload) if query.payload is not None else b""
if not raw:
print("Received query with no payload, ignoring")
return # nothing to do

req_env = Envelope()
req_env.ParseFromString(raw)

# Decode request protobuf
req_type = msgs.get(req_env.payload_msg_type)
# req_type = msgs.get(req_env.payload_msg_type)
req_type = msgs.get(req_env.msg_type)
if req_type is None:
# Unknown message type: ignore (or reply error later)
print(f"Unknown message type '{req_env.msg_type}' in query, ignoring")
return

req_msg = req_type()
Expand All @@ -60,11 +67,13 @@ def _on_query(self, query: zenoh.Query) -> None:
resp_env.sent_seq_index = self._seq_index
resp_env.src_node_name = self._node_name
resp_env.channel = self._channel
resp_env.msg_type = resp_msg.DESCRIPTOR.full_name
resp_env.payload = resp_msg.SerializeToString()

self._seq_index += 1

resp_env = Envelope.pack(self._node_name, self._clock, resp_msg)
query.reply(resp_env.SerializeToString())
with query:
query.reply(query.key_expr, resp_env.SerializeToString())

if self._data_collector:
self._data_collector.append(req_env.SerializeToString())
Expand All @@ -73,4 +82,9 @@ def _on_query(self, query: zenoh.Query) -> None:
except Exception:
# Keep it minimal: don't kill the zenoh callback thread
# You can add logging here if desired
print("Error processing query:")
# write the traceback to stdout for debugging
import traceback
traceback.print_exc()

return
21 changes: 12 additions & 9 deletions src/ark/comm/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from google.protobuf.message import Message
from ark.data.data_collector import DataCollector
from ark.comm.end_point import EndPoint
from ark_msgs.registry import msgs


class Querier(EndPoint):
Expand All @@ -11,12 +12,16 @@ def __init__(
self,
node_name: str,
session: zenoh.Session,
query_target,
clock,
channel: str,
data_collector: DataCollector | None,
):
super().__init__(node_name, session, clock, channel, data_collector)
self._querier = self._session.declare_querier(self._channel)
self._querier = self._session.declare_querier(self._channel,
target=query_target)
print(f"Declared querier on channel: {self._channel}")
self._query_selector = zenoh.Selector(self._channel)

def core_registration(self):
print("..todo: register with ark core..")
Expand Down Expand Up @@ -48,18 +53,21 @@ def query(
else:
raise TypeError("req must be a protobuf Message or bytes")

replies = self._querier.get(value=req_env.SerializeToString(), timeout=timeout)
replies = self._querier.get(payload=req_env.SerializeToString())

for reply in replies:
if reply.ok is None:
continue

resp_env = Envelope()
resp_env.ParseFromString(bytes(reply.ok))
resp_env.ParseFromString(bytes(reply.ok.payload))
resp_env.dst_node_name = self._node_name
resp_env.recv_timestamp = self._clock.now()

resp = resp_env.extract_message()
try:
resp = resp_env.extract_message()
except Exception as e:
continue

self._seq_index += 1

Expand All @@ -69,11 +77,6 @@ def query(

return resp

else:
raise TimeoutError(
f"No OK reply received for query on '{self._channel}' within {timeout}s"
)

def close(self):
super().close()
self._querier.undeclare()
1 change: 1 addition & 0 deletions src/ark/diff/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ark.diff.variable import Variable
108 changes: 108 additions & 0 deletions src/ark/diff/variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import torch
from ark_msgs import Value


class Variable:
    """A node-local scalar in a distributed computation graph.

    A Variable is either an "input" (a leaf parameter that holds a
    ``requires_grad`` torch tensor) or an "output" (a computed quantity whose
    backward pass caches d(output)/d(input) gradients into each input's
    ``_grads`` map). For output variables, one gradient queryable is declared
    per existing input on channel ``grad/{input_name}/{output_name}`` so other
    nodes can query the cached value/gradient pair.

    NOTE(review): only inputs present in ``variables_registry`` at output
    construction time get a gradient channel — creation order matters.
    """

    def __init__(self, name, value, mode, variables_registry, lock, clock, create_queryable_fn, publisher=None):
        # name: identifier used in channel names ("grad/{input}/{output}").
        # value: initial scalar (used only in "input" mode).
        # mode: "input" (leaf parameter) or "output" (computed quantity).
        # variables_registry: shared {name: Variable} dict for this node.
        # lock: threading lock guarding tensors/grads across query callbacks.
        # clock: time source providing .now() timestamps.
        # create_queryable_fn: callback (channel, handler) that declares a queryable.
        # publisher: optional publisher used to broadcast output values.
        self.name = name
        self.mode = mode
        self._variables_registry = variables_registry
        self._lock = lock
        self._clock = clock
        self._grads = {}  # input vars: {output_name: grad_value}
        self._publisher = publisher

        if mode == "input":
            # Leaf parameter: differentiable tensor plus a timestamped history
            # used for replaying past values.
            self._tensor = torch.tensor(value, requires_grad=True)
            self._history = {}
            self._replay_tensor = None
        else:
            # Output: tensor is assigned later via the ``tensor`` setter.
            self._tensor = None
            self._computation_ts = clock.now()
            self._replay_fn = None
            for inp_name, inp_var in variables_registry.items():
                if inp_var.mode == "input":
                    grad_channel = f"grad/{inp_name}/{name}"

                    # create queryable and handler for this input-output
                    # gradient
                    def _make_handler(iv, ov_name, reg, lk):
                        # Factory binds the input variable / output name now so
                        # the handler is not subject to late-binding closure bugs.
                        def handler(_req):

                            # retrieve the output variable
                            out_var = reg.get(ov_name)

                            # check if replay is required
                            # NOTE(review): ``out_var._replay_fn`` is read before
                            # the ``out_var and ...`` None-guard below — if the
                            # output was never registered this raises; confirm
                            # ov_name is always present in the registry.
                            if _req.timestamp != 0 and out_var._replay_fn:
                                val, grad = out_var._replay_fn(_req.timestamp, iv.name, ov_name)
                                return Value(val=val, grad=grad, timestamp=_req.timestamp)

                            # if not replay, return the latest computed value
                            # and grad
                            with lk:
                                val = float(out_var._tensor.detach()) if out_var and out_var._tensor is not None else 0.0
                                grad = iv._grads.get(ov_name, 0.0)
                                ts = out_var._computation_ts if out_var else 0
                            return Value(val=val, grad=grad, timestamp=ts)
                        return handler

                    create_queryable_fn(grad_channel, _make_handler(inp_var, name, variables_registry, self._lock))

    def snapshot(self, ts):
        """Record current tensor value at clock timestamp ts.

        Only meaningful for "input" variables — outputs have no ``_history``.
        """
        self._history[ts] = float(self._tensor.detach())

    def at(self, ts):
        """Return a fresh requires_grad tensor from history at ts.

        Raises KeyError if ``snapshot`` was never called at exactly ``ts``.
        The tensor is also cached on ``_replay_tensor`` for later grad access.
        """
        val = self._history[ts]
        self._replay_tensor = torch.tensor(val, requires_grad=True)
        return self._replay_tensor

    @property
    def tensor(self):
        # Underlying torch tensor (None for an output that was never set).
        return self._tensor

    @tensor.setter
    def tensor(self, value):
        if self.mode == "output":
            # Replace the output tensor wholesale (it carries the autograd
            # graph) and, if configured, broadcast the new scalar value.
            self._tensor = value
            if self._publisher is not None:
                val = float(self._tensor.detach())
                self._publisher.publish(Value(val=val, timestamp=self._clock.now()))
        else:
            # Input: update in place via ``.data`` so requires_grad and any
            # existing graph references to this leaf are preserved.
            self._tensor.data = value.data if isinstance(value, torch.Tensor) else torch.tensor(value)

    def backward(self):
        """Compute and store gradients for this output variable."""
        self._compute_and_store_grads()

    def _is_last_output(self):
        # True when this variable is the most recently registered output;
        # used to decide whether the autograd graph must be retained.
        # NOTE(review): returns ``[]`` (falsy) rather than False when there are
        # no outputs — fine for the ``not`` use below, but not a strict bool.
        output_names = [k for k, v in self._variables_registry.items() if v.mode == "output"]
        return output_names and output_names[-1] == self.name

    def _compute_and_store_grads(self):
        """
        Compute gradients for all input variables with respect to this output
        variable
        """
        # No-op if the output tensor was never set or is detached from the graph.
        if self._tensor is None or not self._tensor.requires_grad:
            return
        with self._lock:

            # zero existing grads for all input variables to ensure correct
            # backward
            for var in self._variables_registry.values():
                if var.mode == "input" and var._tensor.grad is not None:
                    var._tensor.grad.zero_()

            # backward on the output tensor to compute gradients for all input
            # variables
            # retain_graph is kept for every output except the last so that
            # multiple outputs can backprop through shared subgraphs.
            self._tensor.backward(retain_graph=not self._is_last_output())

            # store computed grads for each input variable in the registry
            for var in self._variables_registry.values():
                if var.mode == "input":
                    grad = float(var._tensor.grad) if var._tensor.grad is not None else 0.0
                    var._grads[self.name] = grad
            self._computation_ts = self._clock.now()
54 changes: 51 additions & 3 deletions src/ark/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import time
import threading
import torch
import zenoh
from ark.time.clock import Clock
from ark.time.rate import Rate
Expand All @@ -10,6 +12,8 @@
from ark.comm.queriable import Queryable
from ark.data.data_collector import DataCollector
from ark.core.registerable import Registerable
from ark.diff.variable import Variable
from ark_msgs import VariableInfo


class BaseNode(Registerable):
Expand All @@ -22,7 +26,8 @@ def __init__(
sim: bool = False,
collect_data: bool = False,
):
self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg))
# self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg))
self._z_cfg = z_cfg
self._session = zenoh.open(self._z_cfg)
self._env_name = env_name
self._node_name = node_name
Expand All @@ -36,6 +41,9 @@ def __init__(
self._subs = {}
self._queriers = {}
self._queriables = {}
self._variables = {}
self._grad_lock = threading.Lock()
self._registry_pub = self.create_publisher("ark/vars/register")

self._session.declare_subscriber(f"{env_name}/reset", self._on_reset)

Expand Down Expand Up @@ -73,17 +81,19 @@ def create_subscriber(self, channel, callback) -> Subscriber:
self._subs[channel] = sub
return sub

def create_querier(self, channel, timeout=10.0) -> Querier:
def create_querier(self, channel, target, timeout=10.0) -> Querier:
querier = Querier(
self._node_name,
self._session,
target,
self._clock,
channel,
self._data_collector,
timeout,
# timeout,
)
querier.core_registration()
self._queriers[channel] = querier
# print session and channelinfo for debugging
return querier

def create_queryable(self, channel, handler) -> Queryable:
Expand All @@ -99,6 +109,44 @@ def create_queryable(self, channel, handler) -> Queryable:
self._queriables[channel] = queryable
return queryable

def create_variable(self, name, value, mode="input", subscribe=True):
    """Create and register a differentiable Variable on this node.

    Input variables (leaves) optionally auto-subscribe to "param/{name}"
    so remote parameter updates are written straight into the tensor.
    Output variables get a publisher on "output/{name}", and their gradient
    channels ("grad/{input_name}/{name}", one per already-registered input)
    are announced on the variable registry channel.

    Args:
        name: Variable identifier, used in channel names.
        value: Initial scalar value for the underlying tensor.
        mode: "input" or "output".
        subscribe: If True and mode is "input", auto-subscribe to param/{name}.

    Returns:
        The newly created Variable.
    """
    # Only outputs broadcast their value; inputs get no publisher.
    publisher = None
    if mode == "output":
        publisher = self.create_publisher(f"output/{name}")

    variable = Variable(
        name,
        value,
        mode,
        self._variables,
        self._grad_lock,
        self._clock,
        self.create_queryable,
        publisher=publisher,
    )
    self._variables[name] = variable

    if mode == "input" and subscribe:
        # Bind the variable as a default arg so the callback is not a
        # late-binding closure over the loop/outer scope.
        self.create_subscriber(
            f"param/{name}",
            lambda msg, v=variable: v.tensor.data.fill_(msg.val),
        )
    elif mode == "output":
        # Announce which gradient channels this output serves so queriers
        # can discover them via the registry.
        channels = []
        for inp_name, existing in self._variables.items():
            if existing.mode == "input":
                channels.append(f"grad/{inp_name}/{name}")
        self._registry_pub.publish(VariableInfo(
            output_name=name,
            node_name=self._node_name,
            grad_channels=channels,
        ))

    return variable

def create_rate(self, hz: float):
rate = Rate(self._clock, hz)
self._rates.append(rate)
Expand Down
Loading