|
| 1 | +# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630 |
| 2 | +from typing import Any, Dict, List, Tuple, no_type_check |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +from torch import Tensor |
| 7 | +from tqdm import tqdm |
| 8 | + |
| 9 | +from torch_geometric.data import Data, InMemoryDataset |
| 10 | +from torch_geometric.nn.nlp import SentenceTransformer |
| 11 | + |
| 12 | + |
@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 3,
    cost_e: float = 0.5,
) -> Tuple[Data, str]:
    r"""Retrieves a query-relevant subgraph via Prize-Collecting Steiner
    Tree (PCST) optimization, following the G-Retriever paper.

    Node prizes are rank-based scores of the :obj:`topk` nodes most similar
    to the query embedding; edge prizes are derived from edge/query cosine
    similarity and converted into virtual nodes so that :obj:`pcst_fast`
    (which only supports node prizes) can account for them.

    Args:
        data (Data): The graph with node features :obj:`x`, edge features
            :obj:`edge_attr` and :obj:`edge_index`.
        q_emb (Tensor): The query embedding.
        textual_nodes (Any): DataFrame of textual node descriptions
            (columns :obj:`node_id`, :obj:`node_attr`).
        textual_edges (Any): DataFrame of textual edge descriptions
            (columns :obj:`src`, :obj:`edge_attr`, :obj:`dst`).
        topk (int, optional): Number of top nodes to prize.
            (default: :obj:`3`)
        topk_e (int, optional): Number of top edge-prize levels.
            (default: :obj:`3`)
        cost_e (float, optional): Base cost per edge; lower values admit
            more edges into the subgraph. (default: :obj:`0.5`)

    Returns:
        Tuple[Data, str]: The pruned subgraph and its CSV textualization.
    """
    # Small margin keeping tied edge-prize levels strictly decreasing.
    c = 0.01
    if len(textual_nodes) == 0 or len(textual_edges) == 0:
        # Nothing to prune: return the graph as-is with its description.
        desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv(
            index=False,
            columns=["src", "edge_attr", "dst"],
        )
        return data, desc

    from pcst_fast import pcst_fast

    root = -1  # unrooted PCST
    num_clusters = 1
    pruning = 'gw'
    verbosity_level = 0
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
        topk = min(topk, data.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        # Rank-based prizes: best node gets `topk`, worst selected gets 1.
        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(data.num_nodes)

    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
        topk_e = min(topk_e, e_prizes.unique().size(0))

        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        # Spread each similarity level's rank prize across its ties while
        # keeping the level values strictly decreasing (by a factor 1 - c).
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)
        # reduce the cost of the edges such that at least one edge is selected
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(data.num_edges)

    costs = []
    edges = []
    virtual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}
    mapping_e = {}
    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            # Cheap edge: keep it as a regular edge with reduced cost.
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            # High-prize edge: model it as a zero-cost virtual node so its
            # surplus prize can be collected by the PCST solver.
            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            virtual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs + virtual_costs)
        edges = np.array(edges + virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                pruning, verbosity_level)

    # Split the solution back into real nodes/edges and virtual nodes:
    selected_nodes = vertices[vertices < data.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= data.num_nodes]
    if len(virtual_vertices) > 0:
        # Each selected virtual node maps back to one original edge.
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges + virtual_edges)

    edge_index = data.edge_index[:, selected_edges]
    # Include endpoints of selected edges even if PCST dropped the nodes.
    selected_nodes = np.unique(
        np.concatenate(
            [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    # Relabel node indices to a compact [0, num_selected) range.
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]),
        edge_attr=data.edge_attr[selected_edges],
    )

    return data, desc
| 124 | + |
| 125 | + |
class WebQSPDataset(InMemoryDataset):
    r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
    Labeling for Knowledge Base Question Answering"
    <https://aclanthology.org/P16-2033/>`_ paper.

    Args:
        root (str): Root directory where the dataset should be saved.
        split (str, optional): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        root: str,
        split: str = "train",
        force_reload: bool = False,
    ) -> None:
        # Validate *before* calling `super().__init__()`, which may trigger
        # the expensive download/processing pipeline — an invalid `split`
        # should fail fast instead of after a full processing run.
        if split not in {'train', 'val', 'test'}:
            raise ValueError(f"Invalid 'split' argument (got {split})")

        super().__init__(root, force_reload=force_reload)

        path = self.processed_paths[['train', 'val', 'test'].index(split)]
        self.load(path)

    @property
    def processed_file_names(self) -> List[str]:
        return ['train_data.pt', 'val_data.pt', 'test_data.pt']

    def process(self) -> None:
        import datasets
        import pandas as pd

        # NOTE: Do not rebind the `datasets` module name with the loaded
        # dataset dictionary.
        dataset_dict = datasets.load_dataset('rmanluo/RoG-webqsp')

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_name = 'sentence-transformers/all-roberta-large-v1'
        model = SentenceTransformer(model_name).to(device)
        model.eval()

        for dataset, path in zip(
            [
                dataset_dict['train'],
                dataset_dict['validation'],
                dataset_dict['test'],
            ],
                self.processed_paths,
        ):
            # Embed all questions of the split in one batched pass:
            questions = [example["question"] for example in dataset]
            question_embs = model.encode(
                questions,
                batch_size=256,
                output_device='cpu',
            )

            data_list = []
            for i, example in enumerate(tqdm(dataset)):
                # Build an integer node index over lower-cased entities and
                # collect the (head, relation, tail) triples as edges:
                raw_nodes: Dict[str, int] = {}
                raw_edges = []
                for tri in example["graph"]:
                    h, r, t = tri
                    h = h.lower()
                    t = t.lower()
                    if h not in raw_nodes:
                        raw_nodes[h] = len(raw_nodes)
                    if t not in raw_nodes:
                        raw_nodes[t] = len(raw_nodes)
                    raw_edges.append({
                        "src": raw_nodes[h],
                        "edge_attr": r,
                        "dst": raw_nodes[t]
                    })
                nodes = pd.DataFrame([{
                    "node_id": v,
                    "node_attr": k,
                } for k, v in raw_nodes.items()])
                edges = pd.DataFrame(raw_edges)

                # Missing node attributes become empty strings before
                # encoding:
                nodes.node_attr = nodes.node_attr.fillna("")
                x = model.encode(
                    nodes.node_attr.tolist(),
                    batch_size=256,
                    output_device='cpu',
                )
                edge_attr = model.encode(
                    edges.edge_attr.tolist(),
                    batch_size=256,
                    output_device='cpu',
                )
                edge_index = torch.tensor([
                    edges.src.tolist(),
                    edges.dst.tolist(),
                ])

                question = f"Question: {example['question']}\nAnswer: "
                label = ('|').join(example['answer']).lower()
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                )
                # Prune the graph to a query-relevant subgraph via PCST:
                data, desc = retrieval_via_pcst(
                    data,
                    question_embs[i],
                    nodes,
                    edges,
                    topk=3,
                    topk_e=5,
                    cost_e=0.5,
                )
                data.question = question
                data.label = label
                data.desc = desc
                data_list.append(data)

            self.save(data_list, path)
0 commit comments