Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .devcontainer/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ services:
- -c
- |
tmux new-session -d "source /opt/epics-var.sh; softIocPVA -d /db/softioc.db"
tmux new-session -d "source /opt/epics-var.sh; python3 /db/k2eg-nttable-ioc.py"
tail -f /dev/null
volumes:
- ../tests/epics-test.db:/db/softioc.db
- ../tests/k2eg-nttable-ioc.py:/db/k2eg-nttable-ioc.py

k2eg:
image: ghcr.io/slaclab/k2eg/ubuntu:latest
Expand Down
6 changes: 4 additions & 2 deletions k2eg/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import logging
import threading
import datetime

from enum import Enum
from time import sleep
from readerwriterlock import rwlock
from confluent_kafka import KafkaError
from k2eg.broker import Broker, SnapshotProperties
from k2eg.serialization import MessagePackSerializable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Callable, List, Dict, Any
Expand Down Expand Up @@ -379,7 +381,7 @@ def get(self, pv_url: str, timeout: float = None):
else:
return result

def put(self, pv_url: str, value: any, timeout: float = None):
def put(self, pv_url: str, value: MessagePackSerializable, timeout: float = None):
""" Set the value for a single pv
Args:
pv_name (str): is the name of the pv
Expand All @@ -405,7 +407,7 @@ def put(self, pv_url: str, value: any, timeout: float = None):
# send message to k2eg
self.__broker.send_put_command(
pv_url,
value,
value.to_base_64(),
new_reply_id
)
while(not fetched):
Expand Down
104 changes: 104 additions & 0 deletions k2eg/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import msgpack
import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple, Union

class MessagePackSerializable(ABC):
"""Base class: define msgpack (de)serialization contract."""
@abstractmethod
def to_dict(self) -> Dict[str, Any]:
"""Return a pure-Python dict representation."""
...

def to_msgpack(self) -> bytes:
"""Pack `to_dict()` into msgpack bytes."""
return msgpack.packb(self.to_dict(), use_bin_type=True)

def to_base_64(self) -> str:
"""Pack `to_dict()` into msgpack bytes, then base64 encode."""
return base64.b64encode(self.to_msgpack()).decode('utf-8')

@dataclass
class Scalar(MessagePackSerializable):
"""Wrap a single scalar value, always serialized as a map with a key."""
key: str = field(default="value")
payload: Any = field(default="")

def to_dict(self) -> Dict[str, Any]:
return {
self.key: self.payload
}


@dataclass
class Vector(MessagePackSerializable):
"""Wrap a 1D sequence of values, always serialized as a map with a key."""
key: str = field(default="value")
payload: List[Any] = field(default_factory=list)

def to_dict(self) -> Dict[str, Any]:
return {
self.key: self.payload
}


@dataclass
class Generic(MessagePackSerializable):
key: str = field(default="value")
payload: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {self.key: self.payload}

@dataclass
class NTTable(MessagePackSerializable):
"""
EPICS NTTable format:
- key: identifier
- labels: list of column names
- values: list of rows; each row is a list of values
"""
key: str = field(default="value")
labels: List[str] = field(default_factory=list)
payload: Dict[str, any] = field(default_factory=dict)

def wrap(
self,
records: List[Union[Dict[str, Any], List[Any], Tuple[Any, ...]]]
) -> "NTTable":
"""
Populate self.values from a list of dicts or sequences.
- If record is dict: extract values in order of self.labels.
- If sequence: must match len(self.labels).
"""
self.values.clear()
for rec in records:
if isinstance(rec, dict):
row = [rec[label] for label in self.labels]
elif isinstance(rec, (list, tuple)):
if len(rec) != len(self.labels):
raise ValueError("Row length does not match labels")
row = list(rec)
else:
raise TypeError("Record must be dict or sequence")
self.values.append(row)
return self

def set_column(self, label: str, data: List[Any]) -> "NTTable":
"""
Update values for an existing column.
- label: the column name (must already exist)
- data: new list of values, one per row (length must match)
"""
if label not in self.labels:
raise ValueError(f"Column label '{label}' is not defined")
self.payload[label] = data
return self

def to_dict(self) -> Dict[str, Any]:
return {
self.key: self.payload
}



49 changes: 49 additions & 0 deletions tests/k2eg-nttable-ioc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import time
from p4p.nt import NTTable, NTNDArray
from p4p.server import Server as PVAServer
from p4p.server.thread import SharedPV

# Example p4p server code modified from an original by Matt Gibbs

# Define the structure of the twiss table (what are the columns called, and what data type each column has)
twiss_table_type = NTTable([("element", "s"), ("device_name", "s"),
("s", "d"), ("z", "d"), ("length", "d"), ("p0c", "d"),
("alpha_x", "d"), ("beta_x", "d"), ("eta_x", "d"), ("etap_x", "d"), ("psi_x", "d"),
("alpha_y", "d"), ("beta_y", "d"), ("eta_y", "d"), ("etap_y", "d"), ("psi_y", "d")])

# NOTE: p4p requires you to specify some initial value for every PV - there's not a default.
# here is where we make those default values. This is a particularly hokey example.
twiss_table_rows = []
element_name_list = ["SOL9000", "XC99", "YC99"]
device_name_list = ["SOL:IN20:111", "XCOR:IN20:112", "YCOR:IN20:113"]
for i in range(0,len(element_name_list)):
element_name = element_name_list[i]
device_name = device_name_list[i]
twiss_table_rows.append({"element": element_name, "device_name": device_name, "s": 0.0, "z": 0.0, "length": 0.0, "p0c": 6.0,
"alpha_x": 0.0, "beta_x": 0.0, "eta_x": 0.0, "etap_x": 0.0, "psi_x": 0.0,
"alpha_y": 0.0, "beta_y": 0.0, "eta_y": 0.0, "etap_y": 0.0, "psi_y": 0.0})

# Take the raw data and "wrap" it into the form that the PVA server needs.
initial_twiss_table = twiss_table_type.wrap(twiss_table_rows)

# Define a "handler" that gets called when somebody puts a value into a PV.
# In our case, this is sort of silly, because we just blindly dump whatever
# a client sends into the PV object.
class Handler(object):
def put(self, pv, operation):
try:
pv.post(operation.value(), timestamp=time.time()) # just store the value and update subscribers
operation.done()
except Exception as e:
operation.done(error=str(e))

# Define the PVs that will be hosted by the server.
live_twiss_pv = SharedPV(nt=twiss_table_type, initial=initial_twiss_table, handler=Handler())
image_pv = SharedPV(handler=Handler(), nt=NTNDArray(), initial=np.zeros(1))


# Make the PVA Server. This is where we define the names for each of the PVs we defined above.
# By using "PVAServer.forever", the server immediately starts running, and doesn't stop until you
# kill it.
pva_server = PVAServer.forever(providers=[{"K2EG:TEST:TWISS": live_twiss_pv, "K2EG:TEST:IMAGE": image_pv}])
84 changes: 73 additions & 11 deletions tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import pytest
from unittest import TestCase
from k2eg.serialization import Scalar

k: k2eg.dml = None
TestCase.maxDiff = None
Expand Down Expand Up @@ -65,7 +66,7 @@ def test_exception_on_get_with_bad_protocol():
k.get('unkonwn://', timeout=0.5)

def test_k2eg_get():
get_value = k.get('pva://channel:ramp:ramp')
get_value = k.get('pva://channel:ramp:ramp', timeout=2.0)
assert get_value is not None, "value should not be None"

def test_k2eg_get_timeout():
Expand Down Expand Up @@ -132,22 +133,22 @@ def monitor_handler(pv_name, new_value):
assert received_message_b is not False, "value should not be None"

def test_put():
k.put("pva://variable:a", 0)
k.put("pva://variable:b", 0)
k.put("pva://variable:a", Scalar(payload=0))
k.put("pva://variable:b", Scalar(payload=0))
time.sleep(1)
res_get = k.get("pva://variable:sum")
assert res_get['value'] == 0, "value should not be 0"
k.put("pva://variable:a", 2)
k.put("pva://variable:b", 2)
k.put("pva://variable:a", Scalar(payload=2))
k.put("pva://variable:b", Scalar(payload=2))
#give some time to ioc to update
time.sleep(1)
res_get = k.get("pva://variable:sum")
assert res_get['value'] == 4, "value should not be 0"

def test_multi_threading_put():
put_dic={
"pva://variable:a": 0,
"pva://variable:b": 0
"pva://variable:a": Scalar(payload=0),
"pva://variable:b": Scalar(payload=0)
}
with ThreadPoolExecutor(10) as executor:
for key, value in put_dic.items():
Expand All @@ -156,8 +157,8 @@ def test_multi_threading_put():
res_get = k.get("pva://variable:sum")
assert res_get['value'] == 0, "value should not be 0"
put_dic={
"pva://variable:a": 2,
"pva://variable:b": 2
"pva://variable:a": Scalar(payload=2),
"pva://variable:b": Scalar(payload=2)
}
with ThreadPoolExecutor(10) as executor:
for key, value in put_dic.items():
Expand All @@ -173,16 +174,77 @@ def put(key, value):
print(f"Put {key} with value {value}")
except Exception as e:
print(f"An error occured: {e}")



def test_put_timeout():
with pytest.raises(k2eg.OperationTimeout,
match=r"Timeout.*"):
k.put("pva://bad:pv:name", 0, timeout=0.5)
k.put("pva://bad:pv:name", Scalar(0), timeout=0.5)

# def test_put_nttable():
# nt_labels = [
# "element", "device_name", "s", "z", "length", "p0c",
# "alpha_x", "beta_x", "eta_x", "etap_x", "psi_x",
# "alpha_y", "beta_y", "eta_y", "etap_y", "psi_y"
# ]
# table = NTTable(labels=nt_labels)

# # 3) Add each column of data
# table.set_column("element",["SOL9000", "XC99", "YC99"])
# table.set_column("device_name",["SOL:IN20:111", "XCOR:IN20:112", "YCOR:IN20:113"])
# table.set_column("s", [0.0, 0.0, 0.0])
# table.set_column("z", [0.0, 0.0, 0.0])
# table.set_column("length", [0.0, 0.0, 0.0])
# table.set_column("p0c", [0.0, 0.0, 0.0])
# table.set_column("alpha_x", [0.0, 0.0, 0.0])
# table.set_column("beta_x", [0.0, 0.0, 0.0])
# table.set_column("eta_x", [0.0, 0.0, 0.0])
# table.set_column("etap_x", [0.0, 0.0, 0.0])
# table.set_column("psi_x", [0.0, 0.0, 0.0])
# table.set_column("alpha_y", [0.0, 0.0, 0.0])
# table.set_column("beta_y", [0.0, 0.0, 0.0])
# table.set_column("eta_y", [0.0, 0.0, 0.0])
# table.set_column("etap_y", [0.0, 0.0, 0.0])
# table.set_column("psi_y", [0.0, 0.0, 0.0])
# t_dictionary = table.to_dict()
# k.put("pva://K2EG:TEST:TWISS", table)
# # get the value to check if all ahas been set
# res_get = k.get("pva://K2EG:TEST:TWISS")
# print(res_get)
# assert res_get is not None, "value should not be None"
# assert res_get['value'] is not None, "value should not be None"
# assert res_get['value'] == t_dictionary['value'], "value should be the same as the one putted"

# # now set the vallue to increment of 1 where possible
# table.set_column("element",["1", "2", "3"])
# table.set_column("device_name",["1", "2", "3"])
# table.set_column("s", [1.0, 1.0, 1.0])
# table.set_column("z", [1.0, 1.0, 1.0])
# table.set_column("length", [1.0, 1.0, 1.0])
# table.set_column("p0c", [1.0, 1.0, 1.0])
# table.set_column("alpha_x", [1.0, 1.0, 1.0])
# table.set_column("beta_x", [1.0, 1.0, 1.0])
# table.set_column("eta_x", [1.0, 1.0, 1.0])
# table.set_column("etap_x", [1.0, 1.0, 1.0])
# table.set_column("psi_x", [1.0, 1.0, 1.0])
# table.set_column("alpha_y", [1.0, 1.0, 1.0])
# table.set_column("beta_y", [1.0, 1.0, 1.0])
# table.set_column("eta_y", [1.0, 1.0, 1.0])
# table.set_column("etap_y", [1.0, 1.0, 1.0])
# table.set_column("psi_y", [1.0, 1.0, 1.0])
# updated_t_dictionary = table.to_dict()
# k.put("pva://K2EG:TEST:TWISS", table)
# # get the value to check if all ahas been set
# res_get = k.get("pva://K2EG:TEST:TWISS")
# print(res_get)
# assert res_get is not None, "value should not be None"
# assert res_get['value'] is not None, "value should not be None"
# assert res_get['value'] == updated_t_dictionary['value'], "value should be the same as the one putted"

def test_put_wrong_device_timeout():
with pytest.raises(k2eg.OperationError):
k.put("pva://bad:pv:name", 0)
k.put("pva://bad:pv:name", Scalar(payload=0))

def test_snapshot_on_simple_fixed_pv():
retry = 0
Expand Down
39 changes: 39 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import msgpack
from k2eg.serialization import Scalar, Vector, Generic, NTTable

def test_scalar():
s = Scalar(payload=23.5)
packed_s = s.to_msgpack()
data = msgpack.unpackb(packed_s, raw=False)
# check that i got this msgpack structure {'value': 23.5}
print(f"Packed Scalar: {data}")
assert data == {'value': 23.5}

def test_vector():
v = Vector(payload=[23.5, 24.0, 22.8])
packed_v = v.to_msgpack()
data = msgpack.unpackb(packed_v, raw=False)
# check that i got this msgpack structure {'temperature': [23.5, 24.0, 22.8]}
print(f"Packed Vector: {data}")
assert data == {'value': [23.5, 24.0, 22.8]}

def test_generic():
g = Generic(payload={'key1': 'value1', 'key2': 42})
packed_g = g.to_msgpack()
data = msgpack.unpackb(packed_g, raw=False)
# check that i got this msgpack structure {'value': {'key1': 'value1', 'key2': 42}}
print(f"Packed Generic: {data}")
assert data == {'value': {'key1': 'value1', 'key2': 42}}

def test_nttable():
ntt = NTTable(
labels=["station", "anomaly_state"],
payload=[{'station': 'station_1', 'anomaly_state': True}, {'station': 'station_2', 'anomaly_state': False}]
)
packed_ntt = ntt.to_msgpack()
data = msgpack.unpackb(packed_ntt, raw=False)
# check that i got this msgpack structure
print(f"Packed NTTable: {data}")
assert data == {
'value': [{'station': 'station_1', 'anomaly_state': True}, {'station': 'station_2', 'anomaly_state': False}]
}