
Commit 6d9e850

puririshi98, pre-commit-ci[bot], akihironitta, and rusty1s authored
Add nn.models.GRetriever (#9480)
1. #9462
2. **->** #9480
3. #9481
4. #9167

---

breaking #9167 down further, focusing on G-retriever model this time

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: rusty1s <[email protected]>
1 parent df3df3b commit 6d9e850

File tree

6 files changed: +416 -117 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
 - Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
 - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
 - Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))

test/nn/models/test_g_retriever.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import torch

from torch_geometric.nn import GAT, GRetriever
from torch_geometric.nn.nlp import LLM
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('transformers', 'sentencepiece', 'accelerate')
def test_g_retriever() -> None:
    llm = LLM(
        model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
        num_params=1,
        dtype=torch.float16,
    )

    gnn = GAT(
        in_channels=1024,
        out_channels=1024,
        hidden_channels=1024,
        num_layers=2,
        heads=4,
        norm='batch_norm',
    )

    model = GRetriever(
        llm=llm,
        gnn=gnn,
        mlp_out_channels=2048,
    )
    assert str(model) == ('GRetriever(\n'
                          '  llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n'
                          '  gnn=GAT(1024, 1024, num_layers=2),\n'
                          ')')

    x = torch.randn(10, 1024)
    edge_index = torch.tensor([
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
    ])
    edge_attr = torch.randn(edge_index.size(1), 1024)
    batch = torch.zeros(x.size(0), dtype=torch.long)

    question = ["Is PyG the best open-source GNN library?"]
    label = ["yes!"]

    # Test train:
    loss = model(question, x, edge_index, batch, label, edge_attr)
    assert loss >= 0

    # Test inference:
    pred = model.inference(question, x, edge_index, batch, edge_attr)
    assert len(pred) == 1
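
For readers skimming the test: the `batch` vector is what turns per-node GNN outputs into one embedding per graph inside `GRetriever.encode` (see the model code below). A standalone sketch of just that pooling step, using the same shapes as the test and a random `out` tensor as a stand-in for the GNN output:

```python
import torch

from torch_geometric.utils import scatter

out = torch.randn(10, 1024)                # stand-in for gnn(x, edge_index)
batch = torch.zeros(10, dtype=torch.long)  # all 10 nodes belong to graph 0
graph_emb = scatter(out, batch, dim=0, reduce='mean')  # mean-pool per graph
assert graph_emb.size() == (1, 1024)       # one embedding per graph
```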

test/nn/nlp/test_llm.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 
-from torch_geometric.nn.nlp.llm import LLM
+from torch_geometric.nn.nlp import LLM
 from torch_geometric.testing import onlyFullTest, withPackage
 
 
@@ -12,10 +12,11 @@ def test_llm() -> None:
     answer = ["yes!"]
 
     model = LLM(
-        model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1",
+        model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
         num_params=1,
         dtype=torch.float16,
     )
+    assert str(model) == 'LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1)'
 
     loss = model(question, answer)
     assert isinstance(loss, Tensor)

torch_geometric/nn/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from .pmlp import PMLP
 from .neural_fingerprint import NeuralFingerprint
 from .visnet import ViSNet
+from .g_retriever import GRetriever
 
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
@@ -75,4 +76,5 @@
     'PMLP',
     'NeuralFingerprint',
     'ViSNet',
+    'GRetriever',
 ]
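
For reference, after this re-export the model is importable either from the models subpackage or, as the new test does, from the top-level namespace; a quick check:

```python
# Either path resolves to the same class:
from torch_geometric.nn.models import GRetriever
from torch_geometric.nn import GRetriever as GRetrieverTopLevel

assert GRetriever is GRetrieverTopLevel
```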
torch_geometric/nn/models/g_retriever.py

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
from typing import List, Optional

import torch
from torch import Tensor

from torch_geometric.nn.models import GAT
from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
from torch_geometric.utils import scatter


class GRetriever(torch.nn.Module):
    r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented
    Generation for Textual Graph Understanding and Question Answering"
    <https://arxiv.org/abs/2402.07630>`_ paper.

    Args:
        llm (LLM): The LLM to use.
        gnn (torch.nn.Module): The GNN to use.
        use_lora (bool, optional): If set to :obj:`True`, will use LORA from
            :obj:`peft` for training the LLM, see
            `here <https://huggingface.co/docs/peft/en/index>`_ for details.
            (default: :obj:`False`)
        mlp_out_channels (int, optional): The size of each graph embedding
            after projection. (default: :obj:`4096`)

    .. warning::
        This module has been tested with the following HuggingFace models

        * :obj:`llm_to_use="meta-llama/Llama-2-7b-chat-hf"`
        * :obj:`llm_to_use="google/gemma-7b"`

        and may not work with other models. See other models at `HuggingFace
        Models <https://huggingface.co/models>`_ and let us know if you
        encounter any issues.

    .. note::
        For an example of using :class:`GRetriever`, see
        `examples/llm/g_retriever.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_.
    """
    def __init__(
        self,
        llm: LLM,
        gnn: torch.nn.Module,
        use_lora: bool = False,
        gnn_to_use=GAT,
        mlp_out_channels: int = 4096,
    ) -> None:
        super().__init__()

        self.llm = llm
        self.gnn = gnn.to(self.llm.device)

        self.word_embedding = self.llm.word_embedding
        self.llm_generator = self.llm.llm
        if use_lora:
            from peft import (
                LoraConfig,
                get_peft_model,
                prepare_model_for_kbit_training,
            )
            self.llm_generator = prepare_model_for_kbit_training(
                self.llm_generator)
            lora_r: int = 8
            lora_alpha: int = 16
            lora_dropout: float = 0.05
            lora_target_modules = ['q_proj', 'v_proj']
            config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias='none',
                task_type='CAUSAL_LM',
            )
            self.llm_generator = get_peft_model(self.llm_generator, config)

        mlp_hidden_channels = self.gnn.out_channels
        self.projector = torch.nn.Sequential(
            torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
            torch.nn.Sigmoid(),
            torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
        ).to(self.llm.device)

    def encode(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
    ) -> Tensor:
        x = x.to(self.llm.device)
        edge_index = edge_index.to(self.llm.device)
        if edge_attr is not None:
            edge_attr = edge_attr.to(self.llm.device)
        batch = batch.to(self.llm.device)

        out = self.gnn(x, edge_index, edge_attr=edge_attr)
        return scatter(out, batch, dim=0, reduce='mean')

    def forward(
        self,
        question: List[str],
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        label: List[str],
        edge_attr: Optional[Tensor] = None,
        additional_text_context: Optional[List[str]] = None,
    ):
        r"""The forward pass.

        Args:
            question (List[str]): The questions/prompts.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
            label (List[str]): The answers/labels.
            edge_attr (torch.Tensor, optional): The edge features (if
                supported by the GNN). (default: :obj:`None`)
            additional_text_context (List[str], optional): Additional context
                to give to the LLM, such as textified knowledge graphs.
                (default: :obj:`None`)
        """
        x = self.encode(x, edge_index, batch, edge_attr)
        x = self.projector(x)
        xs = x.split(x.size(0), dim=0)

        (
            inputs_embeds,
            attention_mask,
            label_input_ids,
        ) = self.llm._get_embeds(question, additional_text_context, xs, label)

        with self.llm.autocast_context:
            outputs = self.llm_generator(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=label_input_ids,
            )

        return outputs.loss

    @torch.no_grad()
    def inference(
        self,
        question: List[str],
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor] = None,
        additional_text_context: Optional[List[str]] = None,
        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
    ):
        r"""The inference pass.

        Args:
            question (List[str]): The questions/prompts.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
            edge_attr (torch.Tensor, optional): The edge features (if
                supported by the GNN). (default: :obj:`None`)
            additional_text_context (List[str], optional): Additional context
                to give to the LLM, such as textified knowledge graphs.
                (default: :obj:`None`)
            max_out_tokens (int, optional): How many tokens for the LLM to
                generate. (default: :obj:`32`)
        """
        x = self.encode(x, edge_index, batch, edge_attr)
        x = self.projector(x)
        xs = x.split(x.size(0), dim=0)

        inputs_embeds, attention_mask, _ = self.llm._get_embeds(
            question, additional_text_context, xs)

        bos_token = self.llm.tokenizer(
            BOS,
            add_special_tokens=False,
        ).input_ids[0]

        with self.llm.autocast_context:
            outputs = self.llm_generator.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=max_out_tokens,
                attention_mask=attention_mask,
                bos_token_id=bos_token,
                use_cache=True  # Important to set!
            )

        return self.llm.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True,
        )

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(\n'
                f'  llm={self.llm},\n'
                f'  gnn={self.gnn},\n'
                f')')
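
Putting the new module to work end to end, here is a minimal training sketch. It is not part of the commit: the Llama-2 checkpoint, the 4096-dim embedding-width rationale, the `num_params` value, and the optimizer settings are illustrative assumptions, while the `GRetriever` call signatures match the code above:

```python
import torch

from torch_geometric.nn import GAT, GRetriever
from torch_geometric.nn.nlp import LLM

# Assumption: Llama-2-7B uses 4096-dim token embeddings, which is why
# `mlp_out_channels` defaults to 4096: the pooled graph embedding must
# match the width of the word embeddings it is inserted alongside.
llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
gnn = GAT(in_channels=1024, hidden_channels=1024, out_channels=1024,
          num_layers=2, heads=4)
model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=4096)

# Toy single-graph batch, mirroring the shapes in the new test:
x = torch.randn(10, 1024)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
batch = torch.zeros(x.size(0), dtype=torch.long)
question = ["Is PyG the best open-source GNN library?"]
label = ["yes!"]

# One optimization step; `forward()` returns the language-modeling loss.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # illustrative lr
loss = model(question, x, edge_index, batch, label)
loss.backward()
optimizer.step()
```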
