Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c3fca68
test push 2
Nov 12, 2024
0b7e7e8
loaded and preprocessed data
Nov 13, 2024
cc73892
add CCS prediction to pred_all function
Nov 15, 2024
e728f2f
predict ccs values and calculate mse
Nov 20, 2024
6cb292f
analysed what affects error rates
Nov 26, 2024
78d096b
added RARE_ELEMENTS
Dec 13, 2024
06954d0
added element distribution property
Dec 13, 2024
3d2abfc
rewrote parts that were deleted because of the crash
Dec 13, 2024
733cfda
Begin implementing abstract Trainer class
ch4perone Nov 26, 2024
2e52e77
Rename debug_mode variable
ch4perone Dec 5, 2024
f863437
Implement Super/Subclass of GNN.Trainer and test in debug mode
ch4perone Dec 5, 2024
511b6b2
Debug new trainer class
ch4perone Dec 6, 2024
0a5c7ec
Clean up trainer classes and checkpoint system | Tested in debug mode
ch4perone Dec 6, 2024
2b02af5
Perform 100% run
ch4perone Dec 8, 2024
c8736d4
Perform 75% run
ch4perone Dec 8, 2024
b9cbbd0
Perform 50% run
ch4perone Dec 8, 2024
3917fad
Perform 25% run
ch4perone Dec 9, 2024
2e28c16
Perform 10% run & plot performances
ch4perone Dec 9, 2024
748ab3b
PropertyTrainer class defined to train model that does CCS predictions
Dec 17, 2024
3df8cd8
default value of with_rt and with_ccs set to True in forward func
Jan 8, 2025
cb7fdee
added element_distribution attribute
Jan 8, 2025
c3c0998
created PropertyTrainer class, can be usedd to train a model accordin…
Jan 8, 2025
1d99971
added test keys to generate self.test_data
Feb 18, 2025
8085562
created LinearModel and MLPModel, added test function to PropertyTrai…
Feb 18, 2025
5f3ddfe
added more test metrics to get_default_metrics func
Feb 19, 2025
a55628c
added weight to as_geometric_data function
Feb 19, 2025
961b8c5
fixed linear and mlp models, implemented more test statistics and vis…
Feb 19, 2025
42f84a5
added 'precursor_positive' attribute
Feb 21, 2025
3e7cc47
linear and mlp model takes precursor_positive as input, filtered_df a…
Feb 21, 2025
512de15
added Sulfur as a rare element
Feb 28, 2025
69b029a
added calc_abs_elem_distr function
Feb 28, 2025
306996a
created more plots for Test Results
Feb 28, 2025
187d9c9
added new features to linear and mlp model
Mar 31, 2025
0c6a7af
cleared all outputs
Mar 31, 2025
9c5ae7d
added extra plots; implemented LoRA wrapper and benchmarking
Jun 24, 2025
dbfcf79
Metadaten um spaeter zu visualisieren
Sep 23, 2025
ac33bf6
LoRA Wrapper debug, leafs freezen
Sep 23, 2025
c35476b
Merge main
Sep 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fiora/GNN/GNNModules.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def get_graph_embedding(self, batch):
return pooling_func(X, batch["batch"])


def forward(self, batch, with_RT=False, with_CCS=False):
def forward(self, batch, with_RT=True, with_CCS=True):

# Embed node features
batch["node_embedding"] = self.node_embedding(batch["x"])
Expand Down
24 changes: 15 additions & 9 deletions fiora/GNN/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy, MetricTracker, MetricCollection, Precision, Recall, PrecisionRecallCurve, MeanSquaredError, MeanAbsoluteError, R2Score
from torchmetrics import Accuracy, MetricTracker, MetricCollection, Precision, Recall, PrecisionRecallCurve, MeanSquaredError, MeanAbsoluteError, R2Score, PearsonCorrCoef
from sklearn.model_selection import train_test_split
from typing import Literal, List, Dict, Any


class Trainer(ABC):
def __init__(self, data: Any, train_val_split: float=0.8, split_by_group: bool=False, only_training: bool=False,
train_keys: List[int]=[], val_keys: List[int]=[], seed: int=42, num_workers: int=0, device: str="cpu") -> None:
train_keys: List[int]=[], val_keys: List[int]=[], test_keys: List[int]=[], seed: int=42, num_workers: int=0, device: str="cpu") -> None:


self.only_training = only_training
self.num_workers = num_workers
Expand All @@ -19,29 +20,32 @@ def __init__(self, data: Any, train_val_split: float=0.8, split_by_group: bool=F
self.training_data = data
self.validation_data = Dataset()
elif split_by_group:
self._split_by_group(data, train_val_split, train_keys, val_keys, seed)
self._split_by_group(data, train_val_split, train_keys, val_keys, test_keys, seed)

else:
train_size = int(len(data) * train_val_split)
self.training_data, self.validation_data = torch.utils.data.random_split(
data, [train_size, len(data) - train_size],
generator=torch.Generator().manual_seed(seed)
)


def _split_by_group(self, data, train_val_split: float, train_keys: List[int], val_keys: List[int], seed: int):
def _split_by_group(self, data, train_val_split: float, train_keys: List[int], val_keys: List[int], test_keys: List[int], seed: int):
group_ids = [getattr(x, "group_id") for x in data]
keys = np.unique(group_ids)
if len(train_keys) > 0 and len(val_keys) > 0:
self.train_keys, self.val_keys = train_keys, val_keys
print("Using pre-set train/validation keys")
if len(train_keys) > 0 and len(val_keys) > 0 and len(test_keys) > 0:
self.train_keys, self.val_keys, self.test_keys = train_keys, val_keys, test_keys
print("Using pre-set train/validation/test keys")
else:
self.train_keys, self.val_keys = train_test_split(
keys, test_size=1 - train_val_split, random_state=seed
)
train_ids = np.where([group_id in self.train_keys for group_id in group_ids])[0]
val_ids = np.where([group_id in self.val_keys for group_id in group_ids])[0]
test_ids = np.where([group_id in self.test_keys for group_id in group_ids])[0]
self.training_data = torch.utils.data.Subset(data, train_ids)
self.validation_data = torch.utils.data.Subset(data, val_ids)
self.test_data = torch.utils.data.Subset(data, test_ids)


def _get_default_metrics(self, problem_type: Literal["classification", "regression", "softmax_regression"]):
metrics = {
Expand All @@ -53,7 +57,9 @@ def _get_default_metrics(self, problem_type: Literal["classification", "regressi
}) if problem_type=="classification" else MetricCollection(
{
'mse': MeanSquaredError(),
'mae': MeanAbsoluteError()
'mae': MeanAbsoluteError(),
'r2' : R2Score(),
'pearson' : PearsonCorrCoef()
})).to(self.device)
for data_split in ["train", "val", "masked_val", "test"]
}
Expand Down
95 changes: 94 additions & 1 deletion fiora/MOL/Metabolite.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def __init__(self, SMILES: str|None, InChI: str|None=None, id: int|None=None) ->
self.morganFingerCountOnes = self.morganFinger.GetNumOnBits()
self.id = id
self.loss_weight = 1.0
self.precursor_positive = None
self.ring_count = None
self.presence_rare_elements = None
self.elem_distr_vec = None

def __repr__(self):
return f"<Metabolite: {self.SMILES}>"
Expand All @@ -71,6 +75,22 @@ def __lt__(self, __o: object) -> bool: # TODO not tested!s
return False
return False

# Setter for Precursor Adduct
def set_precursor_positive(self, precursor_adduct):
self.precursor_positive = (precursor_adduct == "[M+H]+")

# Setter for Ring Count
def set_ring_count(self, ring_count):
self.ring_count = ring_count

# Setter for Presence of Rare Elements
def set_presence_rare_elements(self, presence_rare_elem):
self.presence_rare_elements = presence_rare_elem

# Setter for Elem Distribution Vector
def set_elem_distr_vec(self, elem_distr_vec):
self.elem_distr_vec = elem_distr_vec

def get_id(self):
return self.id

Expand Down Expand Up @@ -108,7 +128,51 @@ def draw(self, ax=plt):
# class-specific functions
def create_molecular_structure_graph(self):
self.Graph = mol_to_graph(self.MOL)

def calc_element_distribution(self):
element_distribution = {}
total_elements = len(self.node_elements)

for elem in self.node_elements:
if elem in element_distribution:
element_distribution[elem] += 1
else:
element_distribution[elem] = 1

# Convert counts to ratios
for elem in element_distribution:
element_distribution[elem] /= total_elements

return element_distribution

def calc_abs_elem_distr(self):
element_distribution = {}

for elem in self.node_elements:
if elem in element_distribution:
element_distribution[elem] += 1
else:
element_distribution[elem] = 1

return element_distribution

def calc_abs_elem_distr_vec(self, all_unique_elements):
"""Calculate the absolute element distribution vector with fixed length, for a metabolite."""
# Create a mapping of elements to indices, e.g. {"C": 0, "H": 1, "N": 2, "O": 3, "S": 4}
element_to_index = {element: i for i, element in enumerate(all_unique_elements)}

# Initialize a zero vector of fixed length
element_vector = np.zeros(len(all_unique_elements), dtype=int)

# Count occurrences of each element
unique_elements, counts = np.unique(self.node_elements, return_counts=True)

# Fill the vector using the mapping
for element, count in zip(unique_elements, counts):
element_vector[element_to_index[element]] = count

return element_vector


def compute_graph_attributes(self, node_encoder: AtomFeatureEncoder|None = None, bond_encoder: BondFeatureEncoder|None = None) -> None:

Expand All @@ -135,6 +199,7 @@ def compute_graph_attributes(self, node_encoder: AtomFeatureEncoder|None = None,
# Lists
self.atoms_in_order = [self.Graph.nodes[atom]['atom'] for atom in self.Graph.nodes()]
self.node_elements = [self.Graph.nodes[atom]['atom'].GetSymbol() for atom in self.Graph.nodes()]
self.element_distribution = self.calc_element_distribution()
self.edge_bond_names = [self.Graph[u][v]['bond_type'].name for u,v in self.edges_as_tuples]
if bond_encoder:
self.edge_bond_types = torch.tensor([bond_encoder.number_mapper["bond_type"][bond_name] for bond_name in self.edge_bond_names], dtype=torch.long)
Expand Down Expand Up @@ -337,7 +402,35 @@ def match_fragments_to_peaks(self, mz_fragments, int_list=None, mode_mapper=None
'ms_num_all_peaks': len(mz_fragments)
}

def as_geometric_data(self, with_labels=True):
def as_geometric_data(self, with_labels=True, ccs_only=False):
if ccs_only:
return Data(
x=self.node_features,
edge_index=self.edges.t().contiguous(),
edge_type=self.edge_bond_types,
edge_attr=self.bond_features,
static_graph_features=self.setup_features,
static_edge_features=self.setup_features_per_edge,
static_rt_features = self.rt_setup_features,
weight = self.ExactMolWeight,
precursor_positive = self.precursor_positive,
ring_count = self.ring_count,
presence_rare_elements = self.presence_rare_elements,
elem_distr_vec = self.elem_distr_vec,

# masks and groups
validation_mask=self.is_edge_not_in_ring.bool(),
group_id=self.id,

# additional information
is_node_aromatic=self.is_node_aromatic,
is_edge_aromatic=self.is_edge_aromatic,

ccs = self.ccs,
ccs_mask = self.ccs_mask,

smiles = self.SMILES
)
if with_labels:
return Data(
x=self.node_features,
Expand Down
2 changes: 1 addition & 1 deletion fiora/MOL/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"[M-3H]-": -1 * Descriptors.ExactMolWt(h_2) - 1 * Chem.Descriptors.ExactMolWt(h_plus),
}


RARE_ELEMENTS = ['Br', 'Cl', 'F', 'I', 'S']

PPM = 1/1000000
DEFAULT_PPM = 100 * PPM
Expand Down
9 changes: 5 additions & 4 deletions fiora/MS/SimulationFramework.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ def pred_all(self, df: pd.DataFrame, model: torch.nn.Module|None=None, attr_name
for i,d in df.iterrows():
metabolite = d["Metabolite"]
prediction = self.predict_metabolite_property(metabolite, model=model, as_batch=as_batch)
if self.with_RT:
setattr(metabolite, attr_name + "_pred", prediction["fragment_probs"])
if self.with_RT:
setattr(metabolite, "RT_pred", prediction["rt"].squeeze())
else:
setattr(metabolite, attr_name + "_pred", prediction["fragment_probs"])
if self.with_CCS:
setattr(metabolite, "CCS_pred", prediction["ccs"].squeeze())
setattr(metabolite, attr_name + "_pred", prediction["fragment_probs"])

return


Expand Down
Loading