
Commit bfc6d1a

Authored by puririshi98, pre-commit-ci[bot], akihironitta, and rusty1s
Add WebQSPDataset (#9481)
1. #9462
2. #9480
3. **->** #9481
4. #9167

Breaking down PR #9167 further

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: rusty1s <[email protected]>
1 parent 6d9e850 · commit bfc6d1a

File tree

5 files changed: +257 -6 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added the `WebQSPDataset` dataset ([#9481](https://github.com/pyg-team/pytorch_geometric/pull/9481))
 - Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
 - Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
 - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))

test/nn/nlp/test_sentence_transformer.py

Lines changed: 6 additions & 2 deletions

@@ -23,9 +23,13 @@ def test_sentence_transformer(batch_size, pooling_strategy, device):
     ]
 
     out = model.encode(text, batch_size=batch_size)
+    assert out.device == device
+    assert out.size() == (2, 128)
+
+    out = model.encode(text, batch_size=batch_size, output_device='cpu')
     assert out.is_cpu
     assert out.size() == (2, 128)
 
-    out = model.encode(text, batch_size=batch_size, output_device=device)
+    out = model.encode([], batch_size=batch_size)
     assert out.device == device
-    assert out.size() == (2, 128)
+    assert out.size() == (0, 128)

torch_geometric/datasets/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -61,7 +61,6 @@
 from .twitch import Twitch
 from .airports import Airports
 from .lrgb import LRGBDataset
-from .neurograph import NeuroGraphDataset
 from .malnet_tiny import MalNetTiny
 from .omdb import OMDB
 from .polblogs import PolBlogs
@@ -76,6 +75,8 @@
 from .wikidata import Wikidata5M
 from .myket import MyketDataset
 from .brca_tgca import BrcaTcga
+from .neurograph import NeuroGraphDataset
+from .web_qsp_dataset import WebQSPDataset
 
 from .dbp15k import DBP15K
 from .aminer import AMiner
@@ -188,6 +189,7 @@
     'MyketDataset',
     'BrcaTcga',
     'NeuroGraphDataset',
+    'WebQSPDataset',
 ]
 
 hetero_datasets = [
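With the re-export and `__all__` entry above, the new dataset becomes available from the top-level datasets namespace. A one-line sketch of the resulting import path:

```python
# Public import path implied by the __init__.py change above:
from torch_geometric.datasets import WebQSPDataset
```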
torch_geometric/datasets/web_qsp_dataset.py (new file)

Lines changed: 239 additions & 0 deletions

@@ -0,0 +1,239 @@
# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
from typing import Any, Dict, List, Tuple, no_type_check

import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.nn.nlp import SentenceTransformer


@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 3,
    cost_e: float = 0.5,
) -> Tuple[Data, str]:
    c = 0.01
    if len(textual_nodes) == 0 or len(textual_edges) == 0:
        desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv(
            index=False,
            columns=["src", "edge_attr", "dst"],
        )
        return data, desc

    from pcst_fast import pcst_fast

    root = -1
    num_clusters = 1
    pruning = 'gw'
    verbosity_level = 0
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
        topk = min(topk, data.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(data.num_nodes)

    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
        topk_e = min(topk_e, e_prizes.unique().size(0))

        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)
        # reduce the cost of the edges such that at least one edge is selected
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(data.num_edges)

    costs = []
    edges = []
    virtual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}
    mapping_e = {}
    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            virtual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs + virtual_costs)
        edges = np.array(edges + virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                pruning, verbosity_level)

    selected_nodes = vertices[vertices < data.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= data.num_nodes]
    if len(virtual_vertices) > 0:
        virtual_vertices = vertices[vertices >= data.num_nodes]
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges + virtual_edges)

    edge_index = data.edge_index[:, selected_edges]
    selected_nodes = np.unique(
        np.concatenate(
            [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]),
        edge_attr=data.edge_attr[selected_edges],
    )

    return data, desc


class WebQSPDataset(InMemoryDataset):
    r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
    Labeling for Knowledge Base Question Answering"
    <https://aclanthology.org/P16-2033/>`_ paper.

    Args:
        root (str): Root directory where the dataset should be saved.
        split (str, optional): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        root: str,
        split: str = "train",
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, force_reload=force_reload)

        if split not in {'train', 'val', 'test'}:
            raise ValueError(f"Invalid 'split' argument (got {split})")

        path = self.processed_paths[['train', 'val', 'test'].index(split)]
        self.load(path)

    @property
    def processed_file_names(self) -> List[str]:
        return ['train_data.pt', 'val_data.pt', 'test_data.pt']

    def process(self) -> None:
        import datasets
        import pandas as pd

        datasets = datasets.load_dataset('rmanluo/RoG-webqsp')

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_name = 'sentence-transformers/all-roberta-large-v1'
        model = SentenceTransformer(model_name).to(device)
        model.eval()

        for dataset, path in zip(
            [datasets['train'], datasets['validation'], datasets['test']],
            self.processed_paths,
        ):
            questions = [example["question"] for example in dataset]
            question_embs = model.encode(
                questions,
                batch_size=256,
                output_device='cpu',
            )

            data_list = []
            for i, example in enumerate(tqdm(dataset)):
                raw_nodes: Dict[str, int] = {}
                raw_edges = []
                for tri in example["graph"]:
                    h, r, t = tri
                    h = h.lower()
                    t = t.lower()
                    if h not in raw_nodes:
                        raw_nodes[h] = len(raw_nodes)
                    if t not in raw_nodes:
                        raw_nodes[t] = len(raw_nodes)
                    raw_edges.append({
                        "src": raw_nodes[h],
                        "edge_attr": r,
                        "dst": raw_nodes[t]
                    })
                nodes = pd.DataFrame([{
                    "node_id": v,
                    "node_attr": k,
                } for k, v in raw_nodes.items()])
                edges = pd.DataFrame(raw_edges)

                nodes.node_attr = nodes.node_attr.fillna("")
                x = model.encode(
                    nodes.node_attr.tolist(),
                    batch_size=256,
                    output_device='cpu',
                )
                edge_attr = model.encode(
                    edges.edge_attr.tolist(),
                    batch_size=256,
                    output_device='cpu',
                )
                edge_index = torch.tensor([
                    edges.src.tolist(),
                    edges.dst.tolist(),
                ])

                question = f"Question: {example['question']}\nAnswer: "
                label = ('|').join(example['answer']).lower()
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                )
                data, desc = retrieval_via_pcst(
                    data,
                    question_embs[i],
                    nodes,
                    edges,
                    topk=3,
                    topk_e=5,
                    cost_e=0.5,
                )
                data.question = question
                data.label = label
                data.desc = desc
                data_list.append(data)

            self.save(data_list, path)
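Taken together, `process()` encodes every question, node, and edge text with a `SentenceTransformer`, then prunes each example's knowledge graph with `retrieval_via_pcst` (a prize-collecting Steiner tree over similarity-based node and edge prizes). A minimal usage sketch, assuming the optional dependencies used above (`datasets`, `pandas`, `pcst_fast`) are installed; the `root` path is illustrative:

```python
# Minimal usage sketch (not part of the commit). First use downloads
# 'rmanluo/RoG-webqsp' and runs the full encoding + PCST pipeline, which
# is slow; results are cached under <root>/processed/ afterwards.
from torch_geometric.datasets import WebQSPDataset

dataset = WebQSPDataset(root='data/WebQSP', split='train')

data = dataset[0]     # one pruned question subgraph
print(data.question)  # "Question: ...\nAnswer: " prompt prefix
print(data.label)     # '|'-joined, lowercased answer entities
print(data.desc)      # CSV description of the retrieved subgraph
```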

torch_geometric/nn/nlp/sentence_transformer.py

Lines changed: 8 additions & 3 deletions

@@ -54,8 +54,11 @@ def encode(
         self,
         text: List[str],
         batch_size: Optional[int] = None,
-        output_device: Optional[torch.device] = None,
+        output_device: Optional[Union[torch.device, str]] = None,
     ) -> Tensor:
+        is_empty = len(text) == 0
+        text = ['dummy'] if is_empty else text
+
         batch_size = len(text) if batch_size is None else batch_size
 
         embs: List[Tensor] = []
@@ -70,11 +73,13 @@ def encode(
             emb = self(
                 input_ids=token.input_ids.to(self.device),
                 attention_mask=token.attention_mask.to(self.device),
-            ).to(output_device or 'cpu')
+            ).to(output_device)
 
             embs.append(emb)
 
-        return torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
+        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
+        out = out[:0] if is_empty else out
+        return out
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(model_name={self.model_name})'
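In short, `encode()` now keeps embeddings on the model's device unless `output_device` says otherwise, and an empty input list returns an empty `(0, dim)` tensor instead of failing on the first batch. A small sketch of the revised semantics (the model name and devices are illustrative assumptions, mirroring the test changes above):

```python
# Illustrative sketch (not part of the commit); the model name is an
# assumption -- any Hugging Face encoder checkpoint should work.
import torch
from torch_geometric.nn.nlp import SentenceTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SentenceTransformer('prajjwal1/bert-tiny').to(device)

out = model.encode(['hello', 'world'])  # stays on `device` by default now
assert out.device == device

out = model.encode(['hello'], output_device='cpu')  # plain strings allowed
assert out.is_cpu

out = model.encode([])                  # empty input: shape (0, dim)
assert out.size(0) == 0
```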
