forked from lversen/GNN-Molecules
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup_script.py
More file actions
299 lines (251 loc) · 9.41 KB
/
setup_script.py
File metadata and controls
299 lines (251 loc) · 9.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
"""
Improved setup script for TChemGNN - handles Windows compilation issues.
Replace your existing setup_script.py with this version.
"""
import subprocess
import sys
import os
import platform
import torch
def get_pytorch_info():
    """Return the installed PyTorch version and its CUDA build tag.

    The pair is used to build the URL of the matching pre-built PyTorch
    Geometric wheel index. The tag is therefore derived from the CUDA
    toolkit the installed PyTorch binary was built against
    (``torch.version.cuda``), not from whether a GPU happens to be usable
    right now: gating on ``torch.cuda.is_available()`` would select the
    ``cpu`` wheels for a CUDA build of PyTorch whenever the script runs on a
    machine (or container) without a visible GPU.

    Returns:
        tuple: ``(torch_version, cuda_suffix)``. ``torch_version`` is the
        version string with any local ``+...`` part removed (for example
        ``"2.1.0"``); ``cuda_suffix`` is a ``"cu121"``-style tag, or
        ``"cpu"`` when the build reports no CUDA version. ``(None, None)``
        when PyTorch cannot be imported.
    """
    import importlib

    # PyTorch may have been pip-installed by this same process moments ago;
    # refresh the import finders so a freshly installed package is visible.
    importlib.invalidate_caches()

    try:
        import torch
    except ImportError:
        return None, None

    # Drop the local version label, e.g. "2.1.0+cu121" -> "2.1.0".
    torch_version = torch.__version__.split('+')[0]

    # None (or empty) for builds without CUDA support.
    cuda_version = torch.version.cuda
    if cuda_version:
        cuda_suffix = f"cu{cuda_version.replace('.', '')}"
    else:
        cuda_suffix = "cpu"
    return torch_version, cuda_suffix
def install_pytorch_geometric_windows():
    """Install PyTorch Geometric and its extension packages on Windows.

    Packages are installed one by one with pip, pointing ``-f`` at the PyG
    wheel index that matches the installed PyTorch version and CUDA tag so
    that pre-built wheels are used. If PyTorch is not importable, the CPU
    build is installed first. When pip fails for ``torch-scatter`` or
    ``torch-sparse``, a conda install from the ``pyg`` channel is attempted.

    Failures of individual PyG packages are reported and do not stop the
    remaining installs.

    Raises:
        subprocess.CalledProcessError: If the initial PyTorch installation
            itself fails.
    """
    print("\nDetected Windows - using pre-built wheels for PyTorch Geometric...")

    torch_version, cuda_suffix = get_pytorch_info()
    if torch_version is None:
        print("PyTorch not found. Installing PyTorch first...")
        subprocess.check_call([
            sys.executable, '-m', 'pip', 'install',
            'torch', 'torchvision', 'torchaudio',
            '--index-url', 'https://download.pytorch.org/whl/cpu'
        ])
        torch_version, cuda_suffix = get_pytorch_info()

    print(f"PyTorch version: {torch_version}, CUDA: {cuda_suffix}")

    if torch_version is None:
        # Never build a "torch-None+None" URL; fall back to the default
        # package index only.
        print("Warning: PyTorch could not be imported; "
              "installing without the PyG wheel index")
        find_links = []
    else:
        base_url = f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_suffix}.html"
        find_links = ['-f', base_url]

    pyg_packages = [
        'torch-scatter',
        'torch-sparse',
        'torch-cluster',
        'torch-spline-conv',
        'torch-geometric',
    ]

    # The conda packages in the pyg channel use a "pytorch-" prefix, unlike
    # the pip distributions. Only these two have a conda fallback.
    conda_names = {
        'torch-scatter': 'pytorch-scatter',
        'torch-sparse': 'pytorch-sparse',
    }

    for package in pyg_packages:
        print(f"Installing {package}...")
        try:
            subprocess.check_call([
                sys.executable, '-m', 'pip', 'install', package,
                *find_links
            ])
            print(f"✓ {package} installed successfully")
            continue
        except subprocess.CalledProcessError as e:
            print(f"Warning: Failed to install {package} with wheels")
            print(f"Error: {e}")

        conda_package = conda_names.get(package)
        if conda_package is None:
            continue

        try:
            print(f"Trying conda for {package}...")
            subprocess.check_call([
                'conda', 'install', '-c', 'pyg', '-c', 'conda-forge',
                conda_package, '-y'
            ])
            print(f"✓ {package} installed via conda")
        except (subprocess.CalledProcessError, OSError):
            # OSError covers a missing or non-executable conda binary.
            print(f"✗ Could not install {package}")
def install_pytorch_geometric_unix():
    """Install PyTorch Geometric for Unix-like systems.

    Installs ``torch-geometric``, ``torch-scatter`` and ``torch-sparse`` with
    pip. As in the Windows installer, the PyG wheel index matching the
    installed PyTorch build is passed via ``-f`` so that pre-built extension
    wheels can be used instead of compiling from source. ``-f`` only adds an
    extra location for pip to search, so the default package index is still
    consulted. When PyTorch cannot be imported the wheel index is omitted.

    A failure for one package is reported as a warning and does not stop the
    remaining installs.
    """
    print("\nInstalling PyTorch Geometric for Unix/Linux/Mac...")

    torch_version, cuda_suffix = get_pytorch_info()
    if torch_version is None:
        find_links = []
    else:
        find_links = [
            '-f',
            f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_suffix}.html",
        ]

    packages = [
        'torch-geometric',
        'torch-scatter',
        'torch-sparse',
    ]

    for package in packages:
        print(f"Installing {package}...")
        try:
            subprocess.check_call(
                [sys.executable, '-m', 'pip', 'install', package, *find_links]
            )
            print(f"✓ {package} installed successfully")
        except subprocess.CalledProcessError:
            print(f"Warning: Could not install {package}")
def install_requirements():
    """Install all required packages for TChemGNN, then verify the result.

    Steps, in order: upgrade pip (best effort), install the CPU build of
    PyTorch if it is not importable, install PyTorch Geometric with the
    OS-specific installer, install the core scientific packages, install
    RDKit, and finally run ``verify_installation()``.

    Failures of individual optional steps are printed as warnings so the
    remaining steps still run.

    Raises:
        subprocess.CalledProcessError: If installing PyTorch itself fails;
            nothing else can work without it. The Windows PyG installer may
            raise the same error from its own PyTorch install step.
    """
    print("=" * 50)
    print("Setting up TChemGNN Environment (Improved)")
    print("=" * 50)

    is_windows = platform.system() == 'Windows'

    # Upgrading pip is a convenience only; a failure (for example in a
    # locked-down environment) must not abort the whole setup.
    print("\nUpgrading pip...")
    try:
        subprocess.check_call(
            [sys.executable, '-m', 'pip', 'install', '--upgrade', 'pip']
        )
    except subprocess.CalledProcessError:
        print("Warning: Could not upgrade pip")

    # Install PyTorch first (CPU version for compatibility).
    print("\nInstalling PyTorch...")
    try:
        import torch
        print(f"✓ PyTorch already installed: {torch.__version__}")
    except ImportError:
        subprocess.check_call([
            sys.executable, '-m', 'pip', 'install',
            'torch', 'torchvision', 'torchaudio',
            '--index-url', 'https://download.pytorch.org/whl/cpu'
        ])

    if is_windows:
        install_pytorch_geometric_windows()
    else:
        install_pytorch_geometric_unix()

    core_requirements = [
        'numpy',
        'pandas',
        'scikit-learn',
        'matplotlib',
        'seaborn',
        'tqdm',
        'pillow',
    ]

    print("\nInstalling core scientific packages...")
    for req in core_requirements:
        print(f"Installing {req}...")
        try:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', req])
        except subprocess.CalledProcessError:
            print(f"Warning: Could not install {req}")

    print("\nInstalling RDKit...")
    _install_rdkit()

    print("\n" + "=" * 50)
    print("Setup complete!")
    print("=" * 50)

    print("\nVerifying installation...")
    verify_installation()


def _install_rdkit():
    """Install RDKit via pip, then conda, else print manual instructions."""
    try:
        import rdkit  # noqa: F401 - imported only to detect presence
    except ImportError:
        pass
    else:
        print("✓ RDKit already installed")
        return

    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'rdkit'])
    except subprocess.CalledProcessError:
        print("Pip failed, trying conda...")
    else:
        print("✓ RDKit installed via pip")
        return

    try:
        subprocess.check_call(
            ['conda', 'install', '-c', 'conda-forge', 'rdkit', '-y']
        )
        print("✓ RDKit installed via conda")
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing or non-executable conda binary.
        print("\n⚠ Could not install RDKit automatically.")
        print("Please install RDKit manually:")
        print(" conda install -c conda-forge rdkit")
        print(" or")
        print(" pip install rdkit")
def verify_installation():
    """Check that every required package can be imported and print a report.

    Each module is imported in turn and one status line is printed for it.
    A module that is absent is reported as not installed. A module that is
    present but raises some other exception while importing is reported as
    broken instead of crashing the check; a compiled extension that fails to
    load may do this (for example with an ``OSError``).

    After the per-module lines, either the next steps or troubleshooting
    hints are printed.

    Returns:
        bool: ``True`` when every module imported successfully, ``False``
        otherwise.
    """
    import importlib

    required_modules = [
        ('torch', 'PyTorch'),
        ('torch_geometric', 'PyTorch Geometric'),
        ('torch_scatter', 'PyTorch Scatter'),
        ('torch_sparse', 'PyTorch Sparse'),
        ('rdkit', 'RDKit'),
        ('numpy', 'NumPy'),
        ('pandas', 'Pandas'),
        ('sklearn', 'Scikit-learn'),
        ('matplotlib', 'Matplotlib'),
        ('tqdm', 'tqdm'),
    ]
    # Packages whose absence triggers the extended troubleshooting hints.
    critical_modules = {'torch', 'torch_geometric', 'rdkit'}

    all_good = True
    critical_missing = []

    for module_name, display_name in required_modules:
        try:
            importlib.import_module(module_name)
        except ImportError:
            print(f"✗ {display_name} is NOT installed")
        except Exception as exc:
            # Present on disk but unusable; report it and keep checking the
            # remaining modules.
            print(f"✗ {display_name} is installed but failed to import "
                  f"({type(exc).__name__}: {exc})")
        else:
            print(f"✓ {display_name} is installed")
            continue

        all_good = False
        if module_name in critical_modules:
            critical_missing.append(display_name)

    if all_good:
        print("\n✓ All dependencies are installed correctly!")
        print("\nNext steps:")
        print("1. Download datasets: python download_datasets.py")
        print("2. Run experiments: python main.py --dataset esol")
        print("3. Run all experiments: python run_all_experiments.py")
        return True

    print("\n⚠ Some dependencies are missing.")
    if critical_missing:
        print(f"Critical missing packages: {', '.join(critical_missing)}")
        print("\nTroubleshooting suggestions:")
        print("1. If on Windows, ensure you have Visual Studio Build Tools")
        print("2. Try using conda environment instead:")
        print(" conda create -n tchemgnn python=3.9")
        print(" conda activate tchemgnn")
        print(" conda install pytorch-geometric pytorch-sparse pytorch-scatter -c pyg")
        print(" conda install rdkit -c conda-forge")
    else:
        print("Run the script again or install missing packages manually.")
    return False
def create_conda_environment_file(output_path='environment.yml'):
    """Write a conda environment specification for TChemGNN.

    Args:
        output_path: Where to write the YAML file. Defaults to
            ``environment.yml`` in the current working directory. An
            existing file at that path is overwritten.

    Raises:
        OSError: If the file cannot be opened for writing.
    """
    env_content = """name: tchemgnn
channels:
- pytorch
- pyg
- conda-forge
- defaults
dependencies:
- python=3.9
- pytorch
- torchvision
- torchaudio
- cpuonly # Remove this line if you want GPU support
- pytorch-geometric
- pytorch-scatter
- pytorch-sparse
- rdkit
- numpy
- pandas
- scikit-learn
- matplotlib
- seaborn
- tqdm
- pillow
"""
    # Explicit encoding keeps the output independent of the platform locale.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(env_content)
    print(f"Created {output_path} for conda users")
    print(f"To use: conda env create -f {output_path}")
def main(argv=None):
    """Command-line entry point for the TChemGNN setup script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behavior.

    Options:
        ``--create-conda-env`` writes the conda ``environment.yml`` and
        stops; ``--verify-only`` only checks the installed packages;
        otherwise the full installation runs. ``--force-conda`` is accepted
        but currently only prints a notice.

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Setup TChemGNN environment')
    parser.add_argument('--verify-only', action='store_true',
                        help='Only verify installation without installing packages')
    parser.add_argument('--create-conda-env', action='store_true',
                        help='Create conda environment.yml file')
    parser.add_argument('--force-conda', action='store_true',
                        help='Prefer conda over pip for installations')
    args = parser.parse_args(argv)

    if args.create_conda_env:
        create_conda_environment_file()
        return
    if args.verify_only:
        verify_installation()
        return

    if args.force_conda:
        print("Note: --force-conda not fully implemented yet")
    install_requirements()
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()