Knowledge Graph + Multimodality: Taking Drug Prediction as an Example

Explore the revolutionary application of knowledge graphs and multimodal learning in the field of drug prediction.
Core content:
1. The core role of knowledge graphs in multimodal learning and its potential in drug prediction
2. How KG4MM integrates molecular images and text descriptions to improve the accuracy of drug interaction prediction
3. Use graph neural networks to connect knowledge graphs and multimodal data to achieve more explainable prediction results
Introduction
Knowledge graphs have become an important tool for representing how different entities are related to each other. They encode information as nodes and edges, which visually show the connections between entities. Knowledge Graphs for Multimodal Learning (KG4MM) builds on this by using knowledge graphs to guide the process of learning from images and text. In KG4MM, the graph acts as a "map" that clearly marks the parts of each data type to focus on during training. This guidance helps the system focus on the most relevant features in images and the most informative words in text.
In the field of drug interaction prediction, KG4MM offers clear advantages. The graph structure integrates the molecular images and text descriptions of drugs in a unified framework. This unified perspective helps improve prediction accuracy because it captures both chemical structure and pharmacological background information. In addition, the knowledge graph creates a transparent path from input to output, making it easier to understand why the model makes a specific prediction.
This article will explain how KG4MM is used in practice to predict pairwise drug interactions. It will walk through the steps of building a knowledge graph and integrating molecular and textual information. Through concrete examples, it will illustrate how multimodal learning guided by knowledge graphs can solve practical challenges in medicine and healthcare research. The goal is to show how KG4MM can improve prediction accuracy and interpretability in real-world drug interaction tasks.
Methodology
The KG4MM approach places a knowledge graph at the heart of the entire process. The graph guides how each data type is processed and understood. In the drug interaction example, each drug node in the graph is associated with two pieces of information. The first is a molecular image derived from its SMILES formula, and the second is a textual description containing its class, functional groups, and other key details.
KG4MM is unique in that it leverages graph neural networks (GNNs) to connect graph structures and multimodal data. Based on the placement of a drug in the graph, the GNN determines which parts of its image and which words in its description are most worthy of attention. The edges in the graph—showing how the drug is related to proteins, diseases, and other drugs—help the network determine which visual and textual features are most important. In this way, the knowledge graph does more than just provide additional context; it actively guides the model to focus on the most informative data elements.
The strength of KG4MM lies in its ability to combine pattern recognition neural networks with explicit relationship graphs. GNNs have significant advantages in processing contextual data, so the model can be built on existing knowledge of drug interactions and biochemical properties. This guided learning not only improves prediction accuracy, but also produces clear and interpretable results by highlighting the specific graph connections that affect each prediction.
Implementation Overview
The system is built around a core knowledge graph that integrates all components. This graph captures directed relationships between drugs, proteins, and diseases, such as a drug "binds to" a protein, "inhibits" a target, or "treats" a disease. With the graph at the core of the design, every step of the process draws on this structured medical knowledge.
To prepare the data, the system associates two representations to each drug node: a molecular image and a textual description. The first is a molecular image generated from its SMILES molecular formula using RDKit. The second is a textual description that summarizes the drug class, functional groups, and other relevant details. Both the image and the text are directly connected to the corresponding drug node in the graph, ensuring that the visual and linguistic features are consistent with the underlying knowledge structure.
Modeling the graph itself relies on graph convolutional networks (GCNs). These networks learn from the position of each node and its connections in the graph, creating embeddings that encode how drugs, proteins, and diseases relate to each other. Meanwhile, multimodal encoders convert images and text into feature vectors: a ResNet processes the molecular images, while a BERT model encodes the text descriptions.
Finally, a graph attention network (GAT) fuses the graph embedding with visual and textual features. The attention mechanism exploits the graph structure to weight the most important features from each modality. The combined representation is then fed into a prediction module that determines whether two drugs will interact. At the same time, the attention weights reveal which graph connections, image regions, or text elements contribute most to the model's decision, providing a clear explanation for each prediction.
Detailed implementation
The implementation begins by installing and importing the required libraries, ensuring that all necessary deep learning, graph processing, and cheminformatics packages are available in the environment. It installs PyTorch and torchvision for neural networks, HuggingFace Transformers for text encoding, NetworkX and torch-geometric for graph manipulation, RDKit and Open Babel for working with molecular structures, and supporting libraries such as pandas, NumPy, and Matplotlib. Once everything is installed, the required modules are imported for use in the subsequent cells.
# install necessary packages
!pip install torch torchvision transformers networkx spacy rdflib rdkit pillow scikit-learn matplotlib seaborn torch-geometric
# pip alone could not install Open Babel, so install the system package first
!apt-get install openbabel
!pip install openbabel-wheel

# import libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from transformers import BertModel, BertTokenizer
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json
import os
from rdkit import Chem
from rdkit.Chem import Draw
from PIL import Image
import io
import base64
from openbabel import openbabel
from torch_geometric.data import Data
import torch_geometric.nn as geom_nn
Data preparation
First, a directory is created to store drug images. Then a simplified DrugBank sample is downloaded from a public repository and saved as a TSV file. This file is loaded into a pandas DataFrame to generate a table containing each drug's unique identifier, name, InChI string for molecular structure, and descriptive metadata such as category and group. This structured dataset lays the foundation for generating visual and textual representations in subsequent steps.
# create directory for data storage
!mkdir -p data/drug_images

# download DrugBank sample data (simplified version for demonstration)
!wget -q -O data/drugbank_sample.tsv https://raw.githubusercontent.com/dhimmel/drugbank/gh-pages/data/drugbank-slim.tsv

# load DrugBank data
drug_df = pd.read_csv('data/drugbank_sample.tsv', sep='\t')
Molecular structure conversion
The molecular structures are provided in InChI format. These representations need to be converted to SMILES format using OpenBabel. SMILES stands for Simplified Molecular Input Line Entry System and provides a concise, text-based way to describe chemical structures. SMILES strings are compatible with tools such as RDKit, which can generate molecular images from SMILES strings. The following code shows how to do this conversion.
# create a SMILES column by converting InChI to SMILES
def inchi_to_smiles_openbabel(inchi_str):
    try:
        # create an Open Babel conversion object for InChI -> SMILES
        obConversion = openbabel.OBConversion()
        obConversion.SetInAndOutFormats("inchi", "smiles")
        mol = openbabel.OBMol()
        # convert InChI to a molecule; strip extra newlines or spaces from the output
        if obConversion.ReadString(mol, inchi_str):
            return obConversion.WriteString(mol).strip()
        else:
            return None
    except Exception as e:
        print(f"Error converting InChI to SMILES: {inchi_str}. Error: {e}")
        return None

# apply the conversion to each InChI in the dataframe
drug_df['smiles'] = drug_df['inchi'].apply(inchi_to_smiles_openbabel)
Knowledge graph construction
The system builds a directed medical knowledge graph to capture the relationships between drugs, proteins, and diseases. Each node represents a drug, protein, or disease, and each edge encodes an interaction, such as binds_to, inhibits, or treats. These connections store expert knowledge about how drugs affect biological targets and diseases.
The graph serves as a source of structured relational information that the model leverages alongside image and text features. By explicitly representing domain knowledge, the graph enhances predictive accuracy and the ability to explain why two drugs might interact.
# initialize a medical knowledge graph
medical_kg = nx.DiGraph()

# extract drug entities from DrugBank (limit to 50 drugs for the demo)
drug_entities = drug_df['name'].dropna().unique().tolist()[:50]

# create drug nodes
for drug in drug_entities:
    medical_kg.add_node(drug, type='drug')

# add biomedical entities (proteins, targets, diseases)
protein_entities = ["Cytochrome P450", "Albumin", "P-glycoprotein", "GABA Receptor",
                    "Serotonin Receptor", "Beta-Adrenergic Receptor", "ACE", "HMGCR"]
disease_entities = ["Hypertension", "Diabetes", "Depression", "Epilepsy",
                    "Asthma", "Rheumatoid Arthritis", "Parkinson's Disease"]

for protein in protein_entities:
    medical_kg.add_node(protein, type='protein')
for disease in disease_entities:
    medical_kg.add_node(disease, type='disease')

# add relationships (based on common drug mechanisms and interactions)
# drug-protein relationships
drug_protein_relations = [
    ("Warfarin", "binds_to", "Albumin"),
    ("Atorvastatin", "inhibits", "HMGCR"),
    ("Diazepam", "modulates", "GABA Receptor"),
    ("Fluoxetine", "inhibits", "Serotonin Receptor"),
    ("Phenytoin", "induces", "Cytochrome P450"),
    ("Metoprolol", "blocks", "Beta-Adrenergic Receptor"),
    ("Lisinopril", "inhibits", "ACE"),
    ("Rifampin", "induces", "P-glycoprotein"),
    ("Carbamazepine", "induces", "Cytochrome P450"),
    ("Verapamil", "inhibits", "P-glycoprotein")
]

# drug-disease relationships
drug_disease_relations = [
    ("Lisinopril", "treats", "Hypertension"),
    ("Metformin", "treats", "Diabetes"),
    ("Fluoxetine", "treats", "Depression"),
    ("Phenytoin", "treats", "Epilepsy"),
    ("Albuterol", "treats", "Asthma"),
    ("Methotrexate", "treats", "Rheumatoid Arthritis"),
    ("Levodopa", "treats", "Parkinson's Disease")
]

# drug-drug interaction pairs for the demo (mechanism labels are illustrative)
drug_drug_interactions = [
    ("Goserelin", "interacts_with", "Desmopressin", "increases_anticoagulant_effect"),
    ("Goserelin", "interacts_with", "Cetrorelix", "increases_bleeding_risk"),
    ("Cyclosporine", "interacts_with", "Felypressin", "decreases_efficacy"),
    ("Octreotide", "interacts_with", "Cyanocobalamin", "increases_hypoglycemia_risk"),
    ("Tetrahydrofolic acid", "interacts_with", "L-Histidine", "increases_statin_concentration"),
    ("S-Adenosylmethionine", "interacts_with", "Pyruvic acid", "decreases_efficacy"),
    ("L-Phenylalanine", "interacts_with", "Biotin", "increases_sedation"),
    ("Choline", "interacts_with", "L-Lysine", "decreases_efficacy")
]

# add all relationships to the knowledge graph
for s, r, o in drug_protein_relations:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r)
for s, r, o in drug_disease_relations:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r)
for s, r, o, mechanism in drug_drug_interactions:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r, mechanism=mechanism)
Multimodal data processing
Each drug is represented by three complementary data types. First, its SMILES representation is converted into a molecule object and rendered into an image using RDKit.
# function to generate molecular structure images using RDKit
def generate_molecule_image(smiles_string, size=(224, 224)):
    try:
        mol = Chem.MolFromSmiles(smiles_string)
        if mol:
            img = Draw.MolToImage(mol, size=size)
            return img
        else:
            return None
    except Exception:
        return None
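A quick sanity check of this helper, using aspirin's SMILES string as an assumed example input:

# illustrative smoke test (assumed example input, not part of the pipeline)
aspirin_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
img = generate_molecule_image(aspirin_smiles)
if img is not None:
    img.save("data/drug_images/aspirin_demo.png")  # hypothetical output path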
Second, a descriptive text is constructed by combining the drug name, class, group information, and any other available metadata.
# function to create a text description for each drug, combining various information
def create_drug_description(row):
    description = f"Drug name: {row['name']}. "
    if pd.notna(row.get('category')):
        description += f"Category: {row['category']}. "
    if pd.notna(row.get('groups')):
        description += f"Groups: {row['groups']}. "
    if pd.notna(row.get('description')):
        description += f"Description: {row['description']}"
    return description
Third, the graph itself is embedded. Each node and relation starts as a random vector and is iteratively adjusted so that, for every real connection, the vector of an entity plus the vector of its relation lands close to the vector of the entity it connects to. After many iterations this produces an embedding space in which connected elements cluster naturally and the direction of a relationship is encoded by its relation vector. The result is a pair of lookup tables that map each node and relation to compact, trainable coordinates reflecting the structure of the whole knowledge graph.
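The translational objective described above can be sketched in a few lines. This is a minimal, self-contained illustration; the embedding sizes, the transe_step helper, and the margin value are all assumptions for demonstration, separate from the conversion code that follows.

import torch
import torch.nn as nn

# minimal TransE-style sketch (illustrative only): for a true triple (h, r, t),
# push ||h + r - t|| below the score of a corrupted triple by a margin
num_nodes, num_relations, dim = 65, 8, 64   # assumed sizes for demonstration
node_emb = nn.Embedding(num_nodes, dim)     # lookup table for entities
rel_emb = nn.Embedding(num_relations, dim)  # lookup table for relations
optimizer = torch.optim.Adam(
    list(node_emb.parameters()) + list(rel_emb.parameters()), lr=1e-3)

def transe_step(h, r, t, t_neg, margin=1.0):
    # distances for the true triple and a corrupted triple
    pos = (node_emb(h) + rel_emb(r) - node_emb(t)).norm(dim=-1)
    neg = (node_emb(h) + rel_emb(r) - node_emb(t_neg)).norm(dim=-1)
    # margin ranking loss: true triples should score lower than corrupted ones
    loss = torch.relu(margin + pos - neg).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# one illustrative update on a single (head, relation, tail) index triple
loss = transe_step(torch.tensor([0]), torch.tensor([1]),
                   torch.tensor([2]), torch.tensor([5]))

In the actual implementation below, the NetworkX graph is first converted into a PyTorch Geometric (PyG) data object; the trainable embeddings themselves are learned later by the GCN and GAT layers of the model.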
# convert the NetworkX graph to a PyG graph for graph neural network processing
def convert_nx_to_pyg(nx_graph):
    # create node mappings
    node_to_idx = {node: i for i, node in enumerate(nx_graph.nodes())}

    # create edge lists
    src_nodes = []
    dst_nodes = []
    edge_types = []
    edge_type_to_idx = {}
    for u, v, data in nx_graph.edges(data=True):
        relation = data.get('relation', 'unknown')
        if relation not in edge_type_to_idx:
            edge_type_to_idx[relation] = len(edge_type_to_idx)
        src_nodes.append(node_to_idx[u])
        dst_nodes.append(node_to_idx[v])
        edge_types.append(edge_type_to_idx[relation])

    # create PyG edge tensors
    edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
    edge_type = torch.tensor(edge_types, dtype=torch.long)

    # create node features: one-hot encode node types
    node_types = [nx_graph.nodes[node].get('type', 'unknown') for node in nx_graph.nodes()]
    unique_node_types = sorted(set(node_types))
    node_type_to_idx = {nt: i for i, nt in enumerate(unique_node_types)}
    node_type_features = torch.zeros(len(node_types), len(unique_node_types))
    for i, nt in enumerate(node_types):
        node_type_features[i, node_type_to_idx[nt]] = 1.0

    # create the PyG Data object (node features are stored in 'x')
    g = Data(
        edge_index=edge_index,
        edge_type=edge_type,
        x=node_type_features
    )

    # create reverse mappings for later use
    idx_to_node = {idx: node for node, idx in node_to_idx.items()}
    idx_to_edge_type = {idx: et for et, idx in edge_type_to_idx.items()}
    return g, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type

# convert medical_kg to a PyG graph
pyg_graph, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type = convert_nx_to_pyg(medical_kg)
These visual, textual, and structured representations are saved so that the model can fuse them for interaction prediction.
# process drug data to create multimodal representations
drug_data = []
for idx, row in drug_df.iterrows():
    if row['name'] in drug_entities and pd.notna(row.get('smiles')):
        # generate the molecule image
        img = generate_molecule_image(row['smiles'])
        if img:
            img_path = f"data/drug_images/{row['drugbank_id']}.png"
            img.save(img_path)
            # create the text description
            description = create_drug_description(row)
            # store the drug information
            drug_data.append({
                'id': row['drugbank_id'],
                'name': row['name'],
                'smiles': row['smiles'],
                'description': description,
                'image_path': img_path
            })

drug_data_df = pd.DataFrame(drug_data)
Encoder Development
MultimodalNodeEncoder is a single encoder that converts each node's molecule image and textual summary into compatible feature vectors. It applies a deep convolutional network to the rendered molecular image to distill it into a compact visual fingerprint, and in parallel processes the drug description through a pre-trained language model to extract a semantic summary. The outputs of both branches are then projected into the same vector space so that visual and textual signals can be meaningfully combined under the guidance of the knowledge graph structure.
# processes visual and textual features for nodes
class MultimodalNodeEncoder(nn.Module):
    def __init__(self, output_dim=128):
        super(MultimodalNodeEncoder, self).__init__()
        # image encoder (ResNet)
        resnet = models.resnet18(pretrained=True)
        # remove the final fully connected layer to get 512 features
        self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.image_projection = nn.Linear(512, output_dim)

        # text encoder (BERT)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        # BERT base outputs 768 features
        self.text_projection = nn.Linear(768, output_dim)

    def forward(self, image, text):
        # image encoding
        img_features = self.image_encoder(image).squeeze(-1).squeeze(-1)
        img_features = self.image_projection(img_features)

        # text encoding
        encoded_input = self.tokenizer(text, padding=True, truncation=True,
                                       return_tensors="pt", max_length=128)
        # move the encoded input to the same device as the image
        input_ids = encoded_input['input_ids'].to(image.device)
        attention_mask = encoded_input['attention_mask'].to(image.device)
        text_outputs = self.text_encoder(input_ids=input_ids,
                                         attention_mask=attention_mask)
        # use the [CLS] token embedding (first token)
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_features = self.text_projection(text_features)
        return img_features, text_features
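A minimal smoke test of the encoder (the random image tensor and sample sentence below are assumptions for illustration; both outputs should have shape (1, 128)):

# illustrative smoke test for MultimodalNodeEncoder (downloads pretrained weights)
encoder = MultimodalNodeEncoder(output_dim=128)
dummy_image = torch.randn(1, 3, 224, 224)               # one normalized RGB image
sample_text = ["Drug name: Aspirin. Category: NSAID."]  # hypothetical description
img_vec, txt_vec = encoder(dummy_image, sample_text)
print(img_vec.shape, txt_vec.shape)  # expected: torch.Size([1, 128]) twice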
Model Integration
The KG-guided multimodal model fuses the visual, textual, and type embeddings of each node and predicts drug-drug interactions under the guidance of the knowledge graph. It first projects the image and textual outputs of each node into a shared space and assigns a separate embedding to its node type. These embeddings are then propagated across the graph, allowing each node to fuse its own features with the signals collected from its neighbors. The attention step reweights these fused features based on the strength and type of the connection. When evaluating a pair of drugs, the model takes their refined node embeddings, combines them through concatenation, element-wise products, and differences, and then feeds the results into the prediction head to produce the probability of interaction. By letting the topology of the graph determine how multimodal signals are fused, the predictions generated by the model are both accurate and directly traceable to the underlying network structure.
# define the KG-guided multimodal model
class KGGuidedMultimodalModel(nn.Module):
    def __init__(self, pyg_graph, num_node_types, num_edge_types, node_to_idx,
                 idx_to_node, hidden_dim=128):
        super(KGGuidedMultimodalModel, self).__init__()
        self.pyg_graph = pyg_graph
        self.node_to_idx = node_to_idx
        self.idx_to_node = idx_to_node
        self.hidden_dim = hidden_dim

        # multimodal encoder for processing node-associated data
        self.multimodal_encoder = MultimodalNodeEncoder(output_dim=hidden_dim)

        # node type embeddings
        self.node_type_embedding = nn.Embedding(num_node_types, hidden_dim)

        # graph neural network layers for knowledge graph processing (PyG GCNConv)
        self.gnn_layers = nn.ModuleList([
            geom_nn.GCNConv(hidden_dim, hidden_dim),
            geom_nn.GCNConv(hidden_dim, hidden_dim),
        ])

        # graph attention network for integrating multimodal features with graph structure (PyG GATConv);
        # per-head output is hidden_dim // 4, so the concatenation of 4 heads is hidden_dim
        self.gat_layer = geom_nn.GATConv(hidden_dim, hidden_dim // 4, heads=4)

        # relation prediction head; the input is four hidden_dim-sized vectors concatenated
        self.relation_prediction = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1)
        )

    def get_node_representation(self, node_name, image=None, text=None):
        if node_name not in self.node_to_idx:
            # handle unknown nodes
            return torch.zeros(self.hidden_dim, device=self.pyg_graph.edge_index.device)
        node_idx = self.node_to_idx[node_name]

        # get the node type features (stored in x) and look up the type embedding
        node_type_feat = self.pyg_graph.x[node_idx]
        node_type_embedding = self.node_type_embedding(torch.argmax(node_type_feat))

        # if multimodal data is provided, process it
        if image is not None and text is not None:
            img_feat, text_feat = self.multimodal_encoder(image, text)
            # squeeze out the batch dimension to match shapes
            img_feat = img_feat.squeeze(0)
            text_feat = text_feat.squeeze(0)

            # the knowledge graph guides how multimodal features are integrated:
            # use the node type embedding as a query to attend over the modalities
            attention_weights = torch.softmax(
                torch.matmul(
                    torch.stack([img_feat, text_feat, node_type_embedding]),
                    node_type_embedding
                ),
                dim=0
            )
            # weighted combination of features
            combined_feat = (
                attention_weights[0] * img_feat +
                attention_weights[1] * text_feat +
                attention_weights[2] * node_type_embedding
            )
            return combined_feat
        else:
            # for nodes without multimodal data, just use the type embedding
            return node_type_embedding

    def forward(self, drug1_image, drug1_text, drug1_name,
                drug2_image, drug2_text, drug2_name):
        # initialize features for every node in the graph
        device = self.pyg_graph.edge_index.device
        x = torch.zeros((self.pyg_graph.x.size(0), self.hidden_dim), device=device)

        # fill in multimodal features for the two query drugs
        for i, node_name in enumerate([drug1_name, drug2_name]):
            if node_name in self.node_to_idx:
                node_idx = self.node_to_idx[node_name]
                if i == 0:
                    x[node_idx] = self.get_node_representation(node_name, drug1_image, drug1_text)
                else:
                    x[node_idx] = self.get_node_representation(node_name, drug2_image, drug2_text)

        # apply graph convolutions to propagate information
        edge_index = self.pyg_graph.edge_index
        for layer in self.gnn_layers:
            x = layer(x, edge_index)
            x = torch.relu(x)

        # apply graph attention to integrate features
        x = self.gat_layer(x, edge_index)

        # get the final representations for the two drugs
        drug1_idx = self.node_to_idx.get(drug1_name, 0)
        drug2_idx = self.node_to_idx.get(drug2_name, 0)
        drug1_repr = x[drug1_idx]
        drug2_repr = x[drug2_idx]

        # predict the interaction: combine the representations in multiple ways
        concat_repr = torch.cat([
            drug1_repr,
            drug2_repr,
            drug1_repr * drug2_repr,
            torch.abs(drug1_repr - drug2_repr)
        ], dim=0)
        interaction_prob = torch.sigmoid(self.relation_prediction(concat_repr.unsqueeze(0)).squeeze())
        return interaction_prob
Knowledge Extraction
The algorithm constructs a focused subgraph based on how the two drugs are related in the larger graph. It first looks for any direct edges between the two drugs in the graph and records their properties if found. Next, it identifies proteins or diseases associated with both drugs, revealing shared mechanisms. Finally, it traces all simple paths of no more than a given length to reveal indirect connections through intermediate nodes. The result is a compact network of key nodes and edges that captures the domain knowledge behind the predicted interactions and guides downstream layers to emphasize the most relevant multimodal features.
# function to retrieve the knowledge subgraph relevant to a drug pair
def retrieve_knowledge_subgraph(graph, drug1, drug2, max_path_length=3):
    relevant_knowledge = {
        'direct_interaction': None,
        'common_targets': [],
        'paths': []
    }

    # check for a direct interaction
    if graph.has_edge(drug1, drug2):
        edge_data = graph.get_edge_data(drug1, drug2)
        relevant_knowledge['direct_interaction'] = edge_data

    # find common targets (proteins, diseases)
    drug1_neighbors = set(graph.neighbors(drug1)) if drug1 in graph else set()
    drug2_neighbors = set(graph.neighbors(drug2)) if drug2 in graph else set()
    common_neighbors = drug1_neighbors.intersection(drug2_neighbors)
    for common_node in common_neighbors:
        node_type = graph.nodes[common_node].get('type', '')
        if node_type in ('protein', 'disease'):
            relevant_knowledge['common_targets'].append(common_node)

    # find paths between the drugs (up to max_path_length)
    try:
        paths = list(nx.all_simple_paths(graph, drug1, drug2, cutoff=max_path_length))
        relevant_knowledge['paths'] = paths
    except (nx.NetworkXError, nx.NodeNotFound):
        # handle cases where paths do not exist or nodes are not in the graph
        pass

    return relevant_knowledge
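A short usage sketch of this retrieval function, using one of the demo pairs from the graph built earlier:

# illustrative call on the demo knowledge graph
knowledge = retrieve_knowledge_subgraph(medical_kg, "Goserelin", "Desmopressin")
print(knowledge['direct_interaction'])  # edge attributes if a direct edge exists, else None
print(knowledge['common_targets'])      # shared protein/disease neighbors
print(knowledge['paths'])               # simple paths of length <= 3 between the drugs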
Batch data processing
The function is responsible for preparing each training batch, first discarding all samples that failed to load or were incomplete. It then stacks all valid molecule images into a batched tensor for the drug pair, while collecting their corresponding text summaries and identifiers into parallel lists. Interaction labels are similarly combined into a single tensor. By returning a unified dictionary containing these batched components - or empty placeholders if the remaining samples are invalid - it ensures that the model always receives well-structured, homogeneous input, despite the heterogeneity and occasional missingness of the underlying data.
# custom collate function to handle None values
def custom_collate_fn(batch):
    # filter out None values
    batch = [item for item in batch if item is not None]

    # return an empty batch if all items were None
    if len(batch) == 0:
        return {
            'drug1_img': torch.tensor([]),
            'drug1_text': [],
            'drug1_name': [],
            'drug2_img': torch.tensor([]),
            'drug2_text': [],
            'drug2_name': [],
            'label': torch.tensor([])
        }

    # process the remaining items
    drug1_imgs = torch.stack([item['drug1_img'] for item in batch])
    drug1_texts = [item['drug1_text'] for item in batch]
    drug1_names = [item['drug1_name'] for item in batch]
    drug2_imgs = torch.stack([item['drug2_img'] for item in batch])
    drug2_texts = [item['drug2_text'] for item in batch]
    drug2_names = [item['drug2_name'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])

    return {
        'drug1_img': drug1_imgs,
        'drug1_text': drug1_texts,
        'drug1_name': drug1_names,
        'drug2_img': drug2_imgs,
        'drug2_text': drug2_texts,
        'drug2_name': drug2_names,
        'label': labels
    }
Dataset preparation
The training examples combine all known true interactions with an equal number of randomly sampled non-interacting pairs. The process first extracts every known drug interaction pair, then samples matched negative pairs to balance the dataset. When an example is fetched, the molecular image and text summary for each drug are loaded and preprocessed, pairs with missing data are skipped, and each record contains both drugs' images, descriptions, names, and a binary interaction label. By pairing positive and negative examples, applying consistent image transformations, and handling missing data robustly, the dataset provides reliable, ready-to-use batches for training the interaction prediction model.
# define the dataset for DDI prediction
class DDIDataset(Dataset):
    def __init__(self, drug_data_df, drug_drug_interactions, medical_kg, node_to_idx, transform=None):
        self.drug_data = drug_data_df
        self.drug_name_to_idx = {row['name']: i for i, row in drug_data_df.iterrows()}
        self.node_to_idx = node_to_idx
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # create pairs of drugs with interaction labels
        self.pairs = []
        drug_names = list(self.drug_name_to_idx.keys())

        # positive samples (known interactions)
        for interaction in drug_drug_interactions:
            drug1, _, drug2, _ = interaction
            if drug1 in drug_names and drug2 in drug_names:
                # 1 for positive interaction
                self.pairs.append((drug1, drug2, 1))
        positive_pairs = set((d1, d2) for d1, d2, _ in self.pairs)

        # generate an equal number of negative samples
        np.random.seed(42)
        neg_count = 0
        max_neg = len(self.pairs)
        while neg_count < max_neg:
            i, j = np.random.choice(len(drug_names), 2, replace=False)
            drug1, drug2 = drug_names[i], drug_names[j]
            if (drug1, drug2) not in positive_pairs and (drug2, drug1) not in positive_pairs:
                # 0 for negative interaction
                self.pairs.append((drug1, drug2, 0))
                neg_count += 1

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        try:
            drug1_name, drug2_name, label = self.pairs[idx]

            # get drug1 data
            drug1_idx = self.drug_name_to_idx[drug1_name]
            drug1_data = self.drug_data.iloc[drug1_idx]

            # load the drug1 image with error handling
            try:
                drug1_img = Image.open(drug1_data['image_path']).convert('RGB')
                drug1_img = self.transform(drug1_img)
            except Exception as e:
                print(f"Error loading drug1 image for {drug1_name}: {str(e)}")
                return None
            drug1_text = drug1_data['description']

            # get drug2 data
            drug2_idx = self.drug_name_to_idx[drug2_name]
            drug2_data = self.drug_data.iloc[drug2_idx]

            # load the drug2 image with error handling
            try:
                drug2_img = Image.open(drug2_data['image_path']).convert('RGB')
                drug2_img = self.transform(drug2_img)
            except Exception as e:
                print(f"Error loading drug2 image for {drug2_name}: {str(e)}")
                return None
            drug2_text = drug2_data['description']

            return {
                'drug1_img': drug1_img,
                'drug1_text': drug1_text,
                'drug1_name': drug1_name,
                'drug2_img': drug2_img,
                'drug2_text': drug2_text,
                'drug2_name': drug2_name,
                'label': torch.tensor(label, dtype=torch.float32)
            }
        except Exception as e:
            print(f"Error in __getitem__ for index {idx}: {str(e)}")
            return None
Model Training
At the beginning of model training, the model and its graph are moved to the selected device (GPU or CPU) and run for a set number of epochs, each divided into a training phase and a validation phase. During training, batches of paired drugs are passed through the network to produce interaction scores, binary cross entropy losses are calculated, and the Adam optimizer updates all parameters via backpropagation. Losses and correct prediction counts are aggregated to report average training loss and accuracy at the end of each epoch, skipping empty or malformed batches for stability. The process then switches to evaluation mode - running the same batches but without gradient updates - to measure validation loss and accuracy.
# training function
def train_kg4mm_model(model, train_loader, val_loader, epochs=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.pyg_graph = model.pyg_graph.to(device)

    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    for epoch in range(epochs):
        # training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_samples = 0
        batch_count = 0
        for batch in train_loader:
            # skip empty batches
            if len(batch['drug1_img']) == 0:
                print("Skipping empty batch")
                continue
            batch_count += 1
            try:
                drug1_img = batch['drug1_img'].to(device)
                drug1_text = batch['drug1_text']
                drug1_name = batch['drug1_name']
                drug2_img = batch['drug2_img'].to(device)
                drug2_text = batch['drug2_text']
                drug2_name = batch['drug2_name']
                labels = batch['label'].to(device)

                # forward pass - processing one pair at a time for clarity
                # (in practice, handle batch processing more efficiently)
                batch_size = len(drug1_name)
                outputs = torch.zeros(batch_size, 1, device=device)
                for i in range(batch_size):
                    output = model(
                        drug1_img[i].unsqueeze(0),
                        [drug1_text[i]],
                        drug1_name[i],
                        drug2_img[i].unsqueeze(0),
                        [drug2_text[i]],
                        drug2_name[i]
                    )
                    outputs[i] = output

                # calculate the loss
                loss = criterion(outputs, labels.unsqueeze(1))

                # backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                # calculate accuracy
                predictions = (outputs >= 0.5).float()
                train_correct += (predictions == labels.unsqueeze(1)).sum().item()
                train_samples += batch_size
                print(f"Batch {batch_count}: Loss: {loss.item():.4f}")
            except Exception as e:
                print(f"Error processing batch {batch_count}: {str(e)}")
                import traceback
                traceback.print_exc()
                continue

        avg_train_loss = train_loss / max(1, batch_count)
        train_acc = train_correct / max(1, train_samples)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')

        # validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_samples = 0
        val_batch_count = 0
        with torch.no_grad():
            for batch in val_loader:
                # skip empty batches
                if len(batch['drug1_img']) == 0:
                    continue
                val_batch_count += 1
                try:
                    drug1_img = batch['drug1_img'].to(device)
                    drug1_text = batch['drug1_text']
                    drug1_name = batch['drug1_name']
                    drug2_img = batch['drug2_img'].to(device)
                    drug2_text = batch['drug2_text']
                    drug2_name = batch['drug2_name']
                    labels = batch['label'].to(device)

                    # forward pass - one pair at a time, as in training
                    batch_size = len(drug1_name)
                    outputs = torch.zeros(batch_size, 1, device=device)
                    for i in range(batch_size):
                        output = model(
                            drug1_img[i].unsqueeze(0),
                            [drug1_text[i]],
                            drug1_name[i],
                            drug2_img[i].unsqueeze(0),
                            [drug2_text[i]],
                            drug2_name[i]
                        )
                        outputs[i] = output

                    # calculate the loss
                    loss = criterion(outputs, labels.unsqueeze(1))
                    val_loss += loss.item()

                    # calculate accuracy
                    predictions = (outputs >= 0.5).float()
                    val_correct += (predictions == labels.unsqueeze(1)).sum().item()
                    val_samples += batch_size
                except Exception as e:
                    print(f"Error processing validation batch {val_batch_count}: {str(e)}")
                    continue

        avg_val_loss = val_loss / max(1, val_batch_count)
        val_acc = val_correct / max(1, val_samples)
        print(f'Epoch {epoch+1}/{epochs}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}')

    return model
Prepare data and start training
The prepared dataset then generates pairs of drug examples — each with its molecule image and text summary — and splits them into training and validation sets for performance tracking. The data loader packages these multimodal examples (images, descriptions, and labels) into batches so that they can be smoothly fed into the model. The KG-guided prediction network is instantiated with dimensions derived from the graph’s node and edge types, ensuring that its layers align with the structure of the knowledge graph. Finally, the training loop runs for a fixed number of rounds, alternating between updating the model on the training data and measuring its accuracy on the validation set. This sequence completes the transition from data preparation to active, graph-driven learning.
# initialize the dataset
ddi_dataset = DDIDataset(drug_data_df, drug_drug_interactions, medical_kg, node_to_idx)

# split the dataset into training and validation sets
train_size = int(0.8 * len(ddi_dataset))
val_size = len(ddi_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(ddi_dataset, [train_size, val_size])

# create data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=custom_collate_fn)

# initialize the KG-guided multimodal model with the PyG graph
num_node_types = pyg_graph.x.shape[1]
num_edge_types = len(edge_type_to_idx)
model = KGGuidedMultimodalModel(pyg_graph, num_node_types, num_edge_types, node_to_idx, idx_to_node)

# train the model
trained_model = train_kg4mm_model(model, train_loader, val_loader, epochs=5)
Reasoning and explanation
To make a prediction, the system first loads the processed image and text summary for each drug and looks up where each drug sits in the knowledge graph. The model then produces a probability score reflecting how the visual, textual, and graph information combine. At the same time, the system checks the graph for any direct link between the two drugs, any proteins or diseases they are both connected to, and any simple paths connecting them within a given length. The probability is mapped to a low, moderate, or high risk level, and an explanation is constructed that highlights known interaction mechanisms, shared targets, and the key graph paths that guided the decision. Finally, the system offers example clinical recommendations matched to the risk level, showing concretely how the knowledge graph shapes both the prediction and its interpretation.
def predict_interaction(model, drug1_name, drug2_name, drug_data_df, medical_kg):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # get the drug indices
    drug1_idx = drug_data_df[drug_data_df['name'] == drug1_name].index[0]
    drug2_idx = drug_data_df[drug_data_df['name'] == drug2_name].index[0]

    # get the drug data
    drug1_data = drug_data_df.iloc[drug1_idx]
    drug2_data = drug_data_df.iloc[drug2_idx]

    # prepare the images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    drug1_img = Image.open(drug1_data['image_path']).convert('RGB')
    drug1_img = transform(drug1_img).unsqueeze(0).to(device)
    drug1_text = [drug1_data['description']]
    drug2_img = Image.open(drug2_data['image_path']).convert('RGB')
    drug2_img = transform(drug2_img).unsqueeze(0).to(device)
    drug2_text = [drug2_data['description']]

    # get the knowledge subgraph for the drug pair
    knowledge = retrieve_knowledge_subgraph(medical_kg, drug1_name, drug2_name)

    # make the prediction
    with torch.no_grad():
        interaction_prob = model(
            drug1_img,
            drug1_text,
            drug1_name,
            drug2_img,
            drug2_text,
            drug2_name
        )
    return interaction_prob.item(), knowledge


def explain_interaction_prediction(drug1_name, drug2_name, probability, knowledge):
    explanation = f"KG-guided multimodal analysis for interaction between {drug1_name} and {drug2_name}:\n\n"

    # interpret the probability
    if probability > 0.8:
        risk_level = "High"
    elif probability > 0.5:
        risk_level = "Moderate"
    else:
        risk_level = "Low"
    explanation += f"Interaction Risk Level: {risk_level} (Probability: {probability:.2f})\n\n"

    # explain based on the knowledge graph structure
    explanation += "Knowledge Graph Analysis:\n"
    if knowledge['direct_interaction']:
        mechanism = knowledge['direct_interaction'].get('mechanism', 'unknown mechanism')
        explanation += f"✓ Direct Connection: The knowledge graph contains a documented interaction between these drugs with {mechanism}.\n\n"
    if knowledge['common_targets']:
        explanation += "✓ Common Target Nodes: These drugs connect to shared entities in the knowledge graph:\n"
        for target in knowledge['common_targets']:
            explanation += f"  - {target}\n"
        explanation += "  This graph structure suggests potential interaction through common binding sites or pathways.\n\n"
    if knowledge['paths'] and len(knowledge['paths']) > 0:
        explanation += "✓ Knowledge Graph Pathways: The model identified these connecting paths in the graph:\n"
        for i, path in enumerate(knowledge['paths'][:3]):
            path_str = " → ".join(path)
            explanation += f"  - Path {i+1}: {path_str}\n"
        explanation += "  These graph structures guided the multimodal feature integration for prediction.\n\n"

    # focus on how the KG structure guided the interpretation
    explanation += "Multimodal Integration Process:\n"
    explanation += "  - Knowledge graph structure determined which drug properties were most relevant\n"
    explanation += "  - Graph neural networks analyzed the local neighborhood of both drug nodes\n"
    explanation += "  - Node position in the graph guided the weighting of visual and textual features\n\n"

    # clinical implications (example only - a real system would be more comprehensive)
    if probability > 0.5:
        explanation += "Clinical Recommendations (based on graph analysis):\n"
        explanation += "  - Consider alternative medications not connected in similar graph patterns\n"
        explanation += "  - If co-administration is necessary, monitor for interaction effects\n"
        explanation += "  - Review other drugs connected to the same nodes for potential complications\n"
    else:
        explanation += "Clinical Recommendations (based on graph analysis):\n"
        explanation += "  - Standard monitoring advised\n"
        explanation += "  - The knowledge graph structure suggests minimal interaction concerns\n"
    return explanation
Results
To illustrate the full workflow, two drugs are selected and their pre-generated images and text summaries are loaded and pre-processed as they were during training. These multimodal inputs are then passed through the trained model - now in evaluation mode - to produce a probability score quantifying their risk of interaction. Meanwhile, for visualization and interpretation, the process extracts the relevant part of the knowledge graph by collecting all direct connections, shared biological targets, and any simple paths connecting them that do not exceed a given length, and then augments this subgraph by adding a layer of direct neighbors to obtain broader context.
The extracted subgraph is laid out with a color scheme that distinguishes the two query drugs, proteins, diseases, and other entities, making the network structure clear at a glance. A natural language explanation then ties the probability score to these graph features by highlighting any documented interaction mechanisms, shared targets, and key connecting paths. Together, the risk estimate, the color-coded visualization, and the narrative explanation show how the topology of the knowledge graph guides the fusion of visual and textual signals and provide a transparent rationale for the model's predictions.
# example usage
drug_pair = ("Goserelin", "Desmopressin")
prob, knowledge = predict_interaction(trained_model, drug_pair[0], drug_pair[1], drug_data_df, medical_kg)

print(f"Predicted interaction probability between {drug_pair[0]} and {drug_pair[1]}: {prob:.4f}")
print("\nKnowledge Graph Structure Analysis:")
print(f"Direct connection: {knowledge['direct_interaction']}")
print(f"Common target nodes: {knowledge['common_targets']}")
print("Graph paths connecting drugs:")
for path in knowledge['paths']:
    print(f"  {' -> '.join(path)}")

# visualize the subgraph for these drugs to show the KG-guided approach
plt.figure(figsize=(12, 8))
subgraph_nodes = set([drug_pair[0], drug_pair[1]])

# add intermediate nodes in paths to highlight the KG structure
for path in knowledge['paths']:
    subgraph_nodes.update(path)

# add a level of neighbors to show context in the KG
neighbors_to_add = set()
for node in subgraph_nodes:
    if node in medical_kg:
        neighbors_to_add.update(list(medical_kg.neighbors(node))[:3])
subgraph_nodes.update(neighbors_to_add)
subgraph = medical_kg.subgraph(subgraph_nodes)

# use different colors for node types to emphasize the KG structure
node_colors = []
for node in subgraph.nodes():
    if node == drug_pair[0] or node == drug_pair[1]:
        node_colors.append('lightcoral')
    elif subgraph.nodes[node].get('type') == 'protein':
        node_colors.append('lightblue')
    elif subgraph.nodes[node].get('type') == 'disease':
        node_colors.append('lightgreen')
    else:
        node_colors.append('lightgray')

pos = nx.spring_layout(subgraph, seed=42)
nx.draw(subgraph, pos, with_labels=True, node_color=node_colors,
        node_size=2000, arrows=True, arrowsize=20)
edge_labels = {(s, o): subgraph[s][o]['relation'] for s, o in subgraph.edges()}
nx.draw_networkx_edge_labels(subgraph, pos, edge_labels=edge_labels)
plt.title(f"Knowledge Graph Structure Guiding {drug_pair[0]} and {drug_pair[1]} Interaction Analysis")
plt.savefig('kg_guided_interaction_analysis.png')
plt.show()

# show the explanation
explanation = explain_interaction_prediction(drug_pair[0], drug_pair[1], prob, knowledge)
print(explanation)
When tested on Goserelin and Desmopressin, the model returned a probability of 0.54, classifying the pair as moderate risk. The knowledge graph revealed a direct "interacts_with" edge between the two drugs, whose mechanism attribute is labeled "increases_anticoagulant_effect"; with no shared protein or disease connections, the model focused on that direct relationship. In the rendered subgraph, both drugs are highlighted in coral red and the single directed edge stands out, making clear which relationship drives the prediction.
Conclusion
KG4MM research shows that putting a knowledge graph at the core of the workflow can better fuse molecular images and text than single-source approaches. Each prediction is supported by clear graph evidence—direct edges, shared targets, and connected pathways—that links the results to real biological relationships. By doing so, KG4MM provides stronger predictive power and built-in interpretability in fields such as biochemistry, materials science, and medical diagnostics.