diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 14068de..8b3bc98 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -45,9 +45,11 @@ services: - -c - | tmux new-session -d "source /opt/epics-var.sh; softIocPVA -d /db/softioc.db" + tmux new-session -d "source /opt/epics-var.sh; python3 /db/k2eg-nttable-ioc.py" tail -f /dev/null volumes: - ../tests/epics-test.db:/db/softioc.db + - ../tests/k2eg-nttable-ioc.py:/db/k2eg-nttable-ioc.py k2eg: image: ghcr.io/slaclab/k2eg/ubuntu:latest diff --git a/k2eg/dml.py b/k2eg/dml.py index d28da65..8611c32 100644 --- a/k2eg/dml.py +++ b/k2eg/dml.py @@ -4,11 +4,13 @@ import logging import threading import datetime + from enum import Enum from time import sleep from readerwriterlock import rwlock from confluent_kafka import KafkaError from k2eg.broker import Broker, SnapshotProperties +from k2eg.serialization import MessagePackSerializable from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Callable, List, Dict, Any @@ -379,7 +381,7 @@ def get(self, pv_url: str, timeout: float = None): else: return result - def put(self, pv_url: str, value: any, timeout: float = None): + def put(self, pv_url: str, value: MessagePackSerializable, timeout: float = None): """ Set the value for a single pv Args: pv_name (str): is the name of the pv @@ -405,7 +407,7 @@ def put(self, pv_url: str, value: any, timeout: float = None): # send message to k2eg self.__broker.send_put_command( pv_url, - value, + value.to_base_64(), new_reply_id ) while(not fetched): diff --git a/k2eg/serialization.py b/k2eg/serialization.py new file mode 100644 index 0000000..da091f3 --- /dev/null +++ b/k2eg/serialization.py @@ -0,0 +1,104 @@ +import msgpack +import base64 +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple, Union + +class MessagePackSerializable(ABC): + """Base class: define msgpack (de)serialization contract.""" + @abstractmethod + def to_dict(self) -> Dict[str, Any]: + """Return a pure-Python dict representation.""" + ... + + def to_msgpack(self) -> bytes: + """Pack `to_dict()` into msgpack bytes.""" + return msgpack.packb(self.to_dict(), use_bin_type=True) + + def to_base_64(self) -> str: + """Pack `to_dict()` into msgpack bytes, then base64 encode.""" + return base64.b64encode(self.to_msgpack()).decode('utf-8') + +@dataclass +class Scalar(MessagePackSerializable): + """Wrap a single scalar value, always serialized as a map with a key.""" + key: str = field(default="value") + payload: Any = field(default="") + + def to_dict(self) -> Dict[str, Any]: + return { + self.key: self.payload + } + + +@dataclass +class Vector(MessagePackSerializable): + """Wrap a 1D sequence of values, always serialized as a map with a key.""" + key: str = field(default="value") + payload: List[Any] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + self.key: self.payload + } + + +@dataclass +class Generic(MessagePackSerializable): + key: str = field(default="value") + payload: Dict[str, Any] = field(default_factory=dict) + def to_dict(self) -> Dict[str, Any]: + return {self.key: self.payload} + +@dataclass +class NTTable(MessagePackSerializable): + """ + EPICS NTTable format: + - key: identifier + - labels: list of column names + - values: list of rows; each row is a list of values + """ + key: str = field(default="value") + labels: List[str] = field(default_factory=list) + payload: Dict[str, any] = field(default_factory=dict) + + def wrap( + self, + records: List[Union[Dict[str, Any], List[Any], Tuple[Any, ...]]] + ) -> "NTTable": + """ + Populate self.values from a list of dicts or sequences. + - If record is dict: extract values in order of self.labels. + - If sequence: must match len(self.labels). + """ + self.values.clear() + for rec in records: + if isinstance(rec, dict): + row = [rec[label] for label in self.labels] + elif isinstance(rec, (list, tuple)): + if len(rec) != len(self.labels): + raise ValueError("Row length does not match labels") + row = list(rec) + else: + raise TypeError("Record must be dict or sequence") + self.values.append(row) + return self + + def set_column(self, label: str, data: List[Any]) -> "NTTable": + """ + Update values for an existing column. + - label: the column name (must already exist) + - data: new list of values, one per row (length must match) + """ + if label not in self.labels: + raise ValueError(f"Column label '{label}' is not defined") + self.payload[label] = data + return self + + def to_dict(self) -> Dict[str, Any]: + return { + self.key: self.payload + } + + + \ No newline at end of file diff --git a/tests/k2eg-nttable-ioc.py b/tests/k2eg-nttable-ioc.py new file mode 100644 index 0000000..edcd386 --- /dev/null +++ b/tests/k2eg-nttable-ioc.py @@ -0,0 +1,49 @@ +import numpy as np +import time +from p4p.nt import NTTable, NTNDArray +from p4p.server import Server as PVAServer +from p4p.server.thread import SharedPV + +# Example p4p server code modified from an original by Matt Gibbs + +# Define the structure of the twiss table (what are the columns called, and what data type each column has) +twiss_table_type = NTTable([("element", "s"), ("device_name", "s"), + ("s", "d"), ("z", "d"), ("length", "d"), ("p0c", "d"), + ("alpha_x", "d"), ("beta_x", "d"), ("eta_x", "d"), ("etap_x", "d"), ("psi_x", "d"), + ("alpha_y", "d"), ("beta_y", "d"), ("eta_y", "d"), ("etap_y", "d"), ("psi_y", "d")]) + +# NOTE: p4p requires you to specify some initial value for every PV - there's not a default. +# here is where we make those default values. This is a particularly hokey example. +twiss_table_rows = [] +element_name_list = ["SOL9000", "XC99", "YC99"] +device_name_list = ["SOL:IN20:111", "XCOR:IN20:112", "YCOR:IN20:113"] +for i in range(0,len(element_name_list)): + element_name = element_name_list[i] + device_name = device_name_list[i] + twiss_table_rows.append({"element": element_name, "device_name": device_name, "s": 0.0, "z": 0.0, "length": 0.0, "p0c": 6.0, + "alpha_x": 0.0, "beta_x": 0.0, "eta_x": 0.0, "etap_x": 0.0, "psi_x": 0.0, + "alpha_y": 0.0, "beta_y": 0.0, "eta_y": 0.0, "etap_y": 0.0, "psi_y": 0.0}) + +# Take the raw data and "wrap" it into the form that the PVA server needs. +initial_twiss_table = twiss_table_type.wrap(twiss_table_rows) + +# Define a "handler" that gets called when somebody puts a value into a PV. +# In our case, this is sort of silly, because we just blindly dump whatever +# a client sends into the PV object. +class Handler(object): + def put(self, pv, operation): + try: + pv.post(operation.value(), timestamp=time.time()) # just store the value and update subscribers + operation.done() + except Exception as e: + operation.done(error=str(e)) + +# Define the PVs that will be hosted by the server. +live_twiss_pv = SharedPV(nt=twiss_table_type, initial=initial_twiss_table, handler=Handler()) +image_pv = SharedPV(handler=Handler(), nt=NTNDArray(), initial=np.zeros(1)) + + +# Make the PVA Server. This is where we define the names for each of the PVs we defined above. +# By using "PVAServer.forever", the server immediately starts running, and doesn't stop until you +# kill it. +pva_server = PVAServer.forever(providers=[{"K2EG:TEST:TWISS": live_twiss_pv, "K2EG:TEST:IMAGE": image_pv}]) diff --git a/tests/test_dml.py b/tests/test_dml.py index b8e3710..26f3169 100644 --- a/tests/test_dml.py +++ b/tests/test_dml.py @@ -8,6 +8,7 @@ import time import pytest from unittest import TestCase +from k2eg.serialization import Scalar k: k2eg.dml = None TestCase.maxDiff = None @@ -65,7 +66,7 @@ def test_exception_on_get_with_bad_protocol(): k.get('unkonwn://', timeout=0.5) def test_k2eg_get(): - get_value = k.get('pva://channel:ramp:ramp') + get_value = k.get('pva://channel:ramp:ramp', timeout=2.0) assert get_value is not None, "value should not be None" def test_k2eg_get_timeout(): @@ -132,13 +133,13 @@ def monitor_handler(pv_name, new_value): assert received_message_b is not False, "value should not be None" def test_put(): - k.put("pva://variable:a", 0) - k.put("pva://variable:b", 0) + k.put("pva://variable:a", Scalar(payload=0)) + k.put("pva://variable:b", Scalar(payload=0)) time.sleep(1) res_get = k.get("pva://variable:sum") assert res_get['value'] == 0, "value should not be 0" - k.put("pva://variable:a", 2) - k.put("pva://variable:b", 2) + k.put("pva://variable:a", Scalar(payload=2)) + k.put("pva://variable:b", Scalar(payload=2)) #give some time to ioc to update time.sleep(1) res_get = k.get("pva://variable:sum") @@ -146,8 +147,8 @@ def test_put(): def test_multi_threading_put(): put_dic={ - "pva://variable:a": 0, - "pva://variable:b": 0 + "pva://variable:a": Scalar(payload=0), + "pva://variable:b": Scalar(payload=0) } with ThreadPoolExecutor(10) as executor: for key, value in put_dic.items(): @@ -156,8 +157,8 @@ def test_multi_threading_put(): res_get = k.get("pva://variable:sum") assert res_get['value'] == 0, "value should not be 0" put_dic={ - "pva://variable:a": 2, - "pva://variable:b": 2 + "pva://variable:a": Scalar(payload=2), + "pva://variable:b": Scalar(payload=2) } with ThreadPoolExecutor(10) as executor: for key, value in put_dic.items(): @@ -173,16 +174,77 @@ def put(key, value): print(f"Put {key} with value {value}") except Exception as e: print(f"An error occured: {e}") + + def test_put_timeout(): with pytest.raises(k2eg.OperationTimeout, match=r"Timeout.*"): - k.put("pva://bad:pv:name", 0, timeout=0.5) + k.put("pva://bad:pv:name", Scalar(0), timeout=0.5) +# def test_put_nttable(): +# nt_labels = [ +# "element", "device_name", "s", "z", "length", "p0c", +# "alpha_x", "beta_x", "eta_x", "etap_x", "psi_x", +# "alpha_y", "beta_y", "eta_y", "etap_y", "psi_y" +# ] +# table = NTTable(labels=nt_labels) + +# # 3) Add each column of data +# table.set_column("element",["SOL9000", "XC99", "YC99"]) +# table.set_column("device_name",["SOL:IN20:111", "XCOR:IN20:112", "YCOR:IN20:113"]) +# table.set_column("s", [0.0, 0.0, 0.0]) +# table.set_column("z", [0.0, 0.0, 0.0]) +# table.set_column("length", [0.0, 0.0, 0.0]) +# table.set_column("p0c", [0.0, 0.0, 0.0]) +# table.set_column("alpha_x", [0.0, 0.0, 0.0]) +# table.set_column("beta_x", [0.0, 0.0, 0.0]) +# table.set_column("eta_x", [0.0, 0.0, 0.0]) +# table.set_column("etap_x", [0.0, 0.0, 0.0]) +# table.set_column("psi_x", [0.0, 0.0, 0.0]) +# table.set_column("alpha_y", [0.0, 0.0, 0.0]) +# table.set_column("beta_y", [0.0, 0.0, 0.0]) +# table.set_column("eta_y", [0.0, 0.0, 0.0]) +# table.set_column("etap_y", [0.0, 0.0, 0.0]) +# table.set_column("psi_y", [0.0, 0.0, 0.0]) +# t_dictionary = table.to_dict() +# k.put("pva://K2EG:TEST:TWISS", table) +# # get the value to check if all ahas been set +# res_get = k.get("pva://K2EG:TEST:TWISS") +# print(res_get) +# assert res_get is not None, "value should not be None" +# assert res_get['value'] is not None, "value should not be None" +# assert res_get['value'] == t_dictionary['value'], "value should be the same as the one putted" + +# # now set the vallue to increment of 1 where possible +# table.set_column("element",["1", "2", "3"]) +# table.set_column("device_name",["1", "2", "3"]) +# table.set_column("s", [1.0, 1.0, 1.0]) +# table.set_column("z", [1.0, 1.0, 1.0]) +# table.set_column("length", [1.0, 1.0, 1.0]) +# table.set_column("p0c", [1.0, 1.0, 1.0]) +# table.set_column("alpha_x", [1.0, 1.0, 1.0]) +# table.set_column("beta_x", [1.0, 1.0, 1.0]) +# table.set_column("eta_x", [1.0, 1.0, 1.0]) +# table.set_column("etap_x", [1.0, 1.0, 1.0]) +# table.set_column("psi_x", [1.0, 1.0, 1.0]) +# table.set_column("alpha_y", [1.0, 1.0, 1.0]) +# table.set_column("beta_y", [1.0, 1.0, 1.0]) +# table.set_column("eta_y", [1.0, 1.0, 1.0]) +# table.set_column("etap_y", [1.0, 1.0, 1.0]) +# table.set_column("psi_y", [1.0, 1.0, 1.0]) +# updated_t_dictionary = table.to_dict() +# k.put("pva://K2EG:TEST:TWISS", table) +# # get the value to check if all ahas been set +# res_get = k.get("pva://K2EG:TEST:TWISS") +# print(res_get) +# assert res_get is not None, "value should not be None" +# assert res_get['value'] is not None, "value should not be None" +# assert res_get['value'] == updated_t_dictionary['value'], "value should be the same as the one putted" def test_put_wrong_device_timeout(): with pytest.raises(k2eg.OperationError): - k.put("pva://bad:pv:name", 0) + k.put("pva://bad:pv:name", Scalar(payload=0)) def test_snapshot_on_simple_fixed_pv(): retry = 0 diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..b180313 --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,39 @@ +import msgpack +from k2eg.serialization import Scalar, Vector, Generic, NTTable + +def test_scalar(): + s = Scalar(payload=23.5) + packed_s = s.to_msgpack() + data = msgpack.unpackb(packed_s, raw=False) + # check that i got this msgpack structure {'value': 23.5} + print(f"Packed Scalar: {data}") + assert data == {'value': 23.5} + +def test_vector(): + v = Vector(payload=[23.5, 24.0, 22.8]) + packed_v = v.to_msgpack() + data = msgpack.unpackb(packed_v, raw=False) + # check that i got this msgpack structure {'temperature': [23.5, 24.0, 22.8]} + print(f"Packed Vector: {data}") + assert data == {'value': [23.5, 24.0, 22.8]} + +def test_generic(): + g = Generic(payload={'key1': 'value1', 'key2': 42}) + packed_g = g.to_msgpack() + data = msgpack.unpackb(packed_g, raw=False) + # check that i got this msgpack structure {'value': {'key1': 'value1', 'key2': 42}} + print(f"Packed Generic: {data}") + assert data == {'value': {'key1': 'value1', 'key2': 42}} + +def test_nttable(): + ntt = NTTable( + labels=["station", "anomaly_state"], + payload=[{'station': 'station_1', 'anomaly_state': True}, {'station': 'station_2', 'anomaly_state': False}] + ) + packed_ntt = ntt.to_msgpack() + data = msgpack.unpackb(packed_ntt, raw=False) + # check that i got this msgpack structure + print(f"Packed NTTable: {data}") + assert data == { + 'value': [{'station': 'station_1', 'anomaly_state': True}, {'station': 'station_2', 'anomaly_state': False}] + } \ No newline at end of file