
Commit f49c355

Merge pull request #273 from hydrogenair/issue-204-implementation
fix: implement and improve the privacy LLM framework (#204)
2 parents e2091a6 + b218a88 commit f49c355

18 files changed: +2165 additions, 0 deletions
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
benchmarkingjob:
  name: benchmarkingjob
  namespace: default
  description: ianvs-based cloud-edge collaborative inference benchmark with Sedna
  testenv: "./examples/TAB/cloud_edge_collaborative_inference_bench/test_env/test_env.yaml"

  test_object:
    type: "algorithms"
    algorithms:
      - name: "privacy-aware-query-routing"
        url: "./examples/TAB/cloud_edge_collaborative_inference_bench/test_algorithms/test_algorithms.yaml"

  task:
    type: "sedna_collaborative_inference"
    algorithm:
      name: "privacy-aware-query-routing"
      url: "./examples/TAB/cloud_edge_collaborative_inference_bench/test_algorithms/test_algorithms.yaml"

  evaluation:
    metrics:
      - name: "privacy_metrics"
        url: "./examples/TAB/cloud_edge_collaborative_inference_bench/test_env/privacy_metrics.py"
      - name: "performance_metrics"
        url: "./examples/TAB/cloud_edge_collaborative_inference_bench/test_env/performance_metrics.py"

  rank:
    sort_by:
      - { privacy_score: "descend" }
      - { apr: "descend" }
    visualization:
      mode: "selected_only"
      method: "print_table"
    selected_dataitem:
      paradigms: ["all"]
      modules: ["all"]
      hyperparameters: ["all"]
      metrics: ["all"]
    save_mode: "selected_and_all"

  runtime:
    framework: "sedna"
    max_parallel_tasks: 10
    timeout: 300
    log_level: "INFO"
    sedna:
      worker_nodes: ["edge-node-1", "edge-node-2"]
      coordinator_node: "cloud-node"
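
For reference, the rank.sort_by block above orders results by privacy_score first and falls back to apr as the tie-breaker, both descending. A minimal Python sketch of that ordering under the usual ianvs convention, with fabricated result rows (the real metric values come from the files listed under evaluation.metrics):

# Hypothetical result rows; algorithm names and metric values are fabricated.
results = [
    {"algorithm": "privacy-aware-query-routing", "privacy_score": 0.91, "apr": 0.87},
    {"algorithm": "cloud-only-baseline", "privacy_score": 0.42, "apr": 0.93},
]
# "descend" on privacy_score first, then "descend" on apr to break ties.
ranked = sorted(results, key=lambda r: (r["privacy_score"], r["apr"]), reverse=True)
for row in ranked:
    print(row)
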
Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
import json
import logging
import os
from typing import Dict, List

import numpy as np
from sedna.common.class_factory import ClassFactory, ClassType
from sedna.datasources import BaseDataSource

logging.basicConfig(level=logging.INFO)


@ClassFactory.register(ClassType.GENERAL, alias="ECHRDataProcessor")
class ECHRDataProcessor:
    """Extract sensitive-entity annotations from ECHR court documents.

    Annotations may arrive embedded in the question text, either after an
    "ANNOTATIONS_JSON=" marker or inside <ANN>...</ANN> tags, or as a
    structured dict carried on the sample or label itself.
    """

    def __init__(self, cache_dir="./cache/echr_dataset", **kwargs):
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        self.processed_data = None

    def __call__(self, dataset: BaseDataSource) -> BaseDataSource:
        """Transform the dataset in place for ECHR data processing."""
        try:
            def parse_question_embedded(q: str):
                """Split a question string into (text, annotations dict)."""
                text, annotations = q, {}
                if not isinstance(q, str):
                    return text, annotations
                marker = "\n\nANNOTATIONS_JSON="
                if marker in q:
                    head, tail = q.rsplit(marker, 1)
                    text = head.strip()
                    try:
                        annotations = json.loads(tail.strip())
                    except Exception:
                        annotations = {}
                    return text, annotations
                if "<ANN>" in q and "</ANN>" in q:
                    try:
                        head, rest = q.split("<ANN>", 1)
                        json_str, _ = rest.split("</ANN>", 1)
                        text = head.strip()
                        annotations = json.loads(json_str.strip())
                    except Exception:
                        annotations = {}
                    return text, annotations
                return text, annotations

            processed_items = []

            x_data = dataset.x
            if isinstance(x_data, np.ndarray):
                x_data = x_data.tolist()

            y_data = dataset.y
            if isinstance(y_data, np.ndarray):
                y_data = y_data.tolist()

            for x, y in zip(x_data, y_data):
                # Normalize each sample into a doc dict with text, annotations,
                # and doc_id, whatever shape the sample arrived in.
                if isinstance(x, dict):
                    if "question" in x:
                        text_val, ann = parse_question_embedded(x.get("question", ""))
                        doc = {
                            "text": text_val,
                            "annotations": ann or x.get("annotations", {}),
                            "doc_id": x.get("doc_id"),
                        }
                    else:
                        doc = {
                            "text": x.get("text", ""),
                            "annotations": x.get("annotations", {}),
                            "doc_id": x.get("doc_id"),
                        }
                else:
                    text = str(x) if not isinstance(x, str) else x
                    text_val, ann = parse_question_embedded(text)
                    if not ann:
                        # Fall back to annotations carried on the label side.
                        if not isinstance(y, dict):
                            try:
                                y = json.loads(y) if isinstance(y, str) else {}
                            except Exception:
                                y = {}
                        ann = y.get("annotations", {})
                    doc = {
                        "text": text_val,
                        "annotations": ann,
                        "doc_id": y.get("doc_id") if isinstance(y, dict) else None,
                    }

                # Collect DIRECT and QUASI identifiers as sensitive entities.
                sensitive_entities = []
                total_mentions = 0
                id_type_stats = {}
                for annotator in doc.get("annotations", {}).values():
                    for entity in annotator.get("entity_mentions", []):
                        total_mentions += 1
                        id_type = entity.get("identifier_type")
                        if id_type:
                            id_type_stats[id_type] = id_type_stats.get(id_type, 0) + 1
                        if id_type in ["DIRECT", "QUASI"]:
                            sensitive_entities.append({
                                "span_text": entity["span_text"],
                                "entity_type": entity["entity_type"],
                                "start_offset": entity["start_offset"],
                                "end_offset": entity["end_offset"],
                                "sensitivity": 5 if id_type == "DIRECT" else 3,
                                "identifier_type": id_type,
                                "entity_id": entity["entity_id"],
                            })

                processed_item = {
                    "text": doc["text"],
                    "doc_id": doc.get("doc_id"),
                    "sensitive_entities": sensitive_entities,
                    "raw_doc": doc,
                }
                logging.info(
                    f"Processed doc_id {doc.get('doc_id')}: "
                    f"mentions_total={total_mentions}, id_type_stats={id_type_stats}, "
                    f"selected_sensitive={len(sensitive_entities)}"
                )

                processed_items.append(processed_item)

            dataset.x = processed_items
            self.processed_data = processed_items
        except Exception as e:
            raise RuntimeError(f"Failed to transform dataset for ECHR Data Processor: {e}") from e

        return dataset

    def process(self, dataset: BaseDataSource) -> List[Dict]:
        """Process the dataset and return the list of processed items."""
        self.processed_data = []

        def parse_question_embedded_local(q: str):
            """Split a question string into (text, annotations dict)."""
            marker = "\n\nANNOTATIONS_JSON="
            if isinstance(q, str) and marker in q:
                head, tail = q.rsplit(marker, 1)
                try:
                    return head.strip(), json.loads(tail.strip())
                except Exception:
                    return head.strip(), {}
            if isinstance(q, str) and ("<ANN>" in q and "</ANN>" in q):
                try:
                    head, rest = q.split("<ANN>", 1)
                    json_str, _ = rest.split("</ANN>", 1)
                    return head.strip(), json.loads(json_str.strip())
                except Exception:
                    return q, {}
            return q, {}

        x_data = dataset.x
        if isinstance(x_data, np.ndarray):
            x_data = x_data.tolist()

        y_data = dataset.y
        if isinstance(y_data, np.ndarray):
            y_data = y_data.tolist()

        for x, y in zip(x_data, y_data):
            if isinstance(x, dict):
                if "question" in x:
                    text_val, ann = parse_question_embedded_local(x.get("question", ""))
                    doc = {
                        "text": text_val,
                        "annotations": ann or x.get("annotations", {}),
                        "doc_id": x.get("doc_id"),
                    }
                else:
                    doc = {
                        "text": x.get("text", ""),
                        "annotations": x.get("annotations", {}),
                        "doc_id": x.get("doc_id"),
                    }
            else:
                text = str(x) if not isinstance(x, str) else x
                text_val, ann = parse_question_embedded_local(text)
                if not ann:
                    # Fall back to annotations carried on the label side.
                    if not isinstance(y, dict):
                        try:
                            y = json.loads(y) if isinstance(y, str) else {}
                        except Exception:
                            y = {}
                    ann = y.get("annotations", {})
                doc = {
                    "text": text_val,
                    "annotations": ann,
                    "doc_id": y.get("doc_id") if isinstance(y, dict) else None,
                }

            # Collect DIRECT and QUASI identifiers as sensitive entities.
            sensitive_entities = []
            total_mentions = 0
            id_type_stats = {}
            for annotator in doc.get("annotations", {}).values():
                for entity in annotator.get("entity_mentions", []):
                    total_mentions += 1
                    id_type = entity.get("identifier_type")
                    if id_type:
                        id_type_stats[id_type] = id_type_stats.get(id_type, 0) + 1
                    if id_type in ["DIRECT", "QUASI"]:
                        sensitive_entities.append({
                            "span_text": entity["span_text"],
                            "entity_type": entity["entity_type"],
                            "start_offset": entity["start_offset"],
                            "end_offset": entity["end_offset"],
                            "sensitivity": 5 if id_type == "DIRECT" else 3,
                            "identifier_type": id_type,
                            "entity_id": entity["entity_id"],
                        })

            processed_item = {
                "text": doc["text"],
                "doc_id": doc.get("doc_id"),
                "sensitive_entities": sensitive_entities,
                "raw_doc": doc,
            }
            logging.info(
                f"Processed (mode=process) doc_id {doc.get('doc_id')}: "
                f"mentions_total={total_mentions}, id_type_stats={id_type_stats}, "
                f"selected_sensitive={len(sensitive_entities)}"
            )
            logging.debug(f"processed_item: {processed_item}")
            self.processed_data.append(processed_item)

        return self.processed_data

    def get_processed_data(self) -> List[Dict]:
        if self.processed_data is None:
            raise ValueError("Please call process() to process data first")
        return self.processed_data
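
To see the processor end to end, here is a hypothetical usage sketch. ToyDataSource is a stand-in for sedna.datasources.BaseDataSource (only the .x/.y attributes matter here), the document values are fabricated, and it assumes Sedna is installed so the module above imports:

import json

class ToyDataSource:  # hypothetical stand-in for BaseDataSource
    def __init__(self, x, y):
        self.x, self.y = x, y

# Fabricated annotations in the annotator -> entity_mentions shape the
# processor expects; "John Doe" is tagged as a DIRECT identifier.
annotations = {"annotator_1": {"entity_mentions": [{
    "span_text": "John Doe", "entity_type": "PERSON",
    "start_offset": 0, "end_offset": 8,
    "identifier_type": "DIRECT", "entity_id": "e1"}]}}
question = ("John Doe filed the complaint."
            "\n\nANNOTATIONS_JSON=" + json.dumps(annotations))
ds = ToyDataSource(x=[{"question": question, "doc_id": "echr_001"}], y=[{}])

processor = ECHRDataProcessor()
processor.process(ds)
# The DIRECT identifier "John Doe" is selected with sensitivity 5.
print(processor.get_processed_data()[0]["sensitive_entities"])
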