|
| 1 | +from typing import List, Optional |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import Tensor |
| 5 | + |
| 6 | +from torch_geometric.nn.models import GAT |
| 7 | +from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS |
| 8 | +from torch_geometric.utils import scatter |
| 9 | + |
| 10 | + |
| 11 | +class GRetriever(torch.nn.Module): |
| 12 | + r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented |
| 13 | + Generation for Textual Graph Understanding and Question Answering" |
| 14 | + <https://arxiv.org/abs/2402.07630>`_ paper. |
| 15 | +
|
| 16 | + Args: |
| 17 | + llm (LLM): The LLM to use. |
| 18 | + gnn (torch.nn.Module): The GNN to use. |
| 19 | + use_lora (bool, optional): If set to :obj:`True`, will use LORA from |
| 20 | + :obj:`peft` for training the LLM, see |
| 21 | + `here <https://huggingface.co/docs/peft/en/index>`_ for details. |
| 22 | + (default: :obj:`False`) |
| 23 | + mlp_out_channels (int, optional): The size of each graph embedding |
| 24 | + after projection. (default: :obj:`4096`) |
| 25 | +
|
| 26 | + .. warning:: |
| 27 | + This module has been tested with the following HuggingFace models |
| 28 | +
|
| 29 | + * :obj:`llm_to_use="meta-llama/Llama-2-7b-chat-hf"` |
| 30 | + * :obj:`llm_to_use="google/gemma-7b"` |
| 31 | +
|
| 32 | + and may not work with other models. See other models at `HuggingFace |
| 33 | + Models <https://huggingface.co/models>`_ and let us know if you |
| 34 | + encounter any issues. |
| 35 | +
|
| 36 | + .. note:: |
| 37 | + For an example of using :class:`GRetriever`, see |
| 38 | + `examples/llm/g_retriever.py <https://github.com/pyg-team/ |
| 39 | + pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_. |
| 40 | + """ |
| 41 | + def __init__( |
| 42 | + self, |
| 43 | + llm: LLM, |
| 44 | + gnn: torch.nn.Module, |
| 45 | + use_lora: bool = False, |
| 46 | + gnn_to_use=GAT, |
| 47 | + mlp_out_channels: int = 4096, |
| 48 | + ) -> None: |
| 49 | + super().__init__() |
| 50 | + |
| 51 | + self.llm = llm |
| 52 | + self.gnn = gnn.to(self.llm.device) |
| 53 | + |
| 54 | + self.word_embedding = self.llm.word_embedding |
| 55 | + self.llm_generator = self.llm.llm |
| 56 | + if use_lora: |
| 57 | + from peft import ( |
| 58 | + LoraConfig, |
| 59 | + get_peft_model, |
| 60 | + prepare_model_for_kbit_training, |
| 61 | + ) |
| 62 | + self.llm_generator = prepare_model_for_kbit_training( |
| 63 | + self.llm_generator) |
| 64 | + lora_r: int = 8 |
| 65 | + lora_alpha: int = 16 |
| 66 | + lora_dropout: float = 0.05 |
| 67 | + lora_target_modules = ['q_proj', 'v_proj'] |
| 68 | + config = LoraConfig( |
| 69 | + r=lora_r, |
| 70 | + lora_alpha=lora_alpha, |
| 71 | + target_modules=lora_target_modules, |
| 72 | + lora_dropout=lora_dropout, |
| 73 | + bias='none', |
| 74 | + task_type='CAUSAL_LM', |
| 75 | + ) |
| 76 | + self.llm_generator = get_peft_model(self.llm_generator, config) |
| 77 | + |
| 78 | + mlp_hidden_channels = self.gnn.out_channels |
| 79 | + self.projector = torch.nn.Sequential( |
| 80 | + torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels), |
| 81 | + torch.nn.Sigmoid(), |
| 82 | + torch.nn.Linear(mlp_hidden_channels, mlp_out_channels), |
| 83 | + ).to(self.llm.device) |
| 84 | + |
| 85 | + def encode( |
| 86 | + self, |
| 87 | + x: Tensor, |
| 88 | + edge_index: Tensor, |
| 89 | + batch: Tensor, |
| 90 | + edge_attr: Optional[Tensor], |
| 91 | + ) -> Tensor: |
| 92 | + x = x.to(self.llm.device) |
| 93 | + edge_index = edge_index.to(self.llm.device) |
| 94 | + if edge_attr is not None: |
| 95 | + edge_attr = edge_attr.to(self.llm.device) |
| 96 | + batch = batch.to(self.llm.device) |
| 97 | + |
| 98 | + out = self.gnn(x, edge_index, edge_attr=edge_attr) |
| 99 | + return scatter(out, batch, dim=0, reduce='mean') |
| 100 | + |
| 101 | + def forward( |
| 102 | + self, |
| 103 | + question: List[str], |
| 104 | + x: Tensor, |
| 105 | + edge_index: Tensor, |
| 106 | + batch: Tensor, |
| 107 | + label: List[str], |
| 108 | + edge_attr: Optional[Tensor] = None, |
| 109 | + additional_text_context: Optional[List[str]] = None, |
| 110 | + ): |
| 111 | + r"""The forward pass. |
| 112 | +
|
| 113 | + Args: |
| 114 | + question (List[str]): The questions/prompts. |
| 115 | + x (torch.Tensor): The input node features. |
| 116 | + edge_index (torch.Tensor): The edge indices. |
| 117 | + batch (torch.Tensor): The batch vector |
| 118 | + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns |
| 119 | + each element to a specific example. |
| 120 | + label (List[str]): The answers/labels. |
| 121 | + edge_attr (torch.Tensor, optional): The edge features (if supported |
| 122 | + by the GNN). (default: :obj:`None`) |
| 123 | + additional_text_context (List[str], optional): Additional context |
| 124 | + to give to the LLM, such as textified knowledge graphs. |
| 125 | + (default: :obj:`None`) |
| 126 | + """ |
| 127 | + x = self.encode(x, edge_index, batch, edge_attr) |
| 128 | + x = self.projector(x) |
| 129 | + xs = x.split(x.size(0), dim=0) |
| 130 | + |
| 131 | + ( |
| 132 | + inputs_embeds, |
| 133 | + attention_mask, |
| 134 | + label_input_ids, |
| 135 | + ) = self.llm._get_embeds(question, additional_text_context, xs, label) |
| 136 | + |
| 137 | + with self.llm.autocast_context: |
| 138 | + outputs = self.llm_generator( |
| 139 | + inputs_embeds=inputs_embeds, |
| 140 | + attention_mask=attention_mask, |
| 141 | + return_dict=True, |
| 142 | + labels=label_input_ids, |
| 143 | + ) |
| 144 | + |
| 145 | + return outputs.loss |
| 146 | + |
| 147 | + @torch.no_grad() |
| 148 | + def inference( |
| 149 | + self, |
| 150 | + question: List[str], |
| 151 | + x: Tensor, |
| 152 | + edge_index: Tensor, |
| 153 | + batch: Tensor, |
| 154 | + edge_attr: Optional[Tensor] = None, |
| 155 | + additional_text_context: Optional[List[str]] = None, |
| 156 | + max_out_tokens: Optional[int] = MAX_NEW_TOKENS, |
| 157 | + ): |
| 158 | + r"""The inference pass. |
| 159 | +
|
| 160 | + Args: |
| 161 | + question (List[str]): The questions/prompts. |
| 162 | + x (torch.Tensor): The input node features. |
| 163 | + edge_index (torch.Tensor): The edge indices. |
| 164 | + batch (torch.Tensor): The batch vector |
| 165 | + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns |
| 166 | + each element to a specific example. |
| 167 | + edge_attr (torch.Tensor, optional): The edge features (if supported |
| 168 | + by the GNN). (default: :obj:`None`) |
| 169 | + additional_text_context (List[str], optional): Additional context |
| 170 | + to give to the LLM, such as textified knowledge graphs. |
| 171 | + (default: :obj:`None`) |
| 172 | + max_out_tokens (int, optional): How many tokens for the LLM to |
| 173 | + generate. (default: :obj:`32`) |
| 174 | + """ |
| 175 | + x = self.encode(x, edge_index, batch, edge_attr) |
| 176 | + x = self.projector(x) |
| 177 | + xs = x.split(x.size(0), dim=0) |
| 178 | + |
| 179 | + inputs_embeds, attention_mask, _ = self.llm._get_embeds( |
| 180 | + question, additional_text_context, xs) |
| 181 | + |
| 182 | + bos_token = self.llm.tokenizer( |
| 183 | + BOS, |
| 184 | + add_special_tokens=False, |
| 185 | + ).input_ids[0] |
| 186 | + |
| 187 | + with self.llm.autocast_context: |
| 188 | + outputs = self.llm_generator.generate( |
| 189 | + inputs_embeds=inputs_embeds, |
| 190 | + max_new_tokens=max_out_tokens, |
| 191 | + attention_mask=attention_mask, |
| 192 | + bos_token_id=bos_token, |
| 193 | + use_cache=True # Important to set! |
| 194 | + ) |
| 195 | + |
| 196 | + return self.llm.tokenizer.batch_decode( |
| 197 | + outputs, |
| 198 | + skip_special_tokens=True, |
| 199 | + ) |
| 200 | + |
| 201 | + def __repr__(self) -> str: |
| 202 | + return (f'{self.__class__.__name__}(\n' |
| 203 | + f' llm={self.llm},\n' |
| 204 | + f' gnn={self.gnn},\n' |
| 205 | + f')') |
0 commit comments