Knowledge Graph + Multimodality: Taking Drug Prediction as an Example

Written by Iris Vance
Updated on: June 20, 2025

Explore the revolutionary application of knowledge graphs and multimodal learning in the field of drug prediction.

Core content:
1. The core role of knowledge graphs in multimodal learning and their potential in drug prediction
2. How KG4MM integrates molecular images and text descriptions to improve the accuracy of drug interaction prediction
3. How graph neural networks connect the knowledge graph with multimodal data to yield more explainable predictions

Introduction

Knowledge graphs have become an important tool for representing how different entities are related to each other. They encode information as nodes and edges, which visually show the connections between entities. Knowledge Graphs for Multimodal Learning (KG4MM) builds on this by using knowledge graphs to guide the process of learning from images and text. In KG4MM, the graph acts as a "map" that clearly marks the parts of each data type to focus on during training. This guidance helps the system focus on the most relevant features in images and the most informative words in text.

In the field of drug interaction prediction, KG4MM offers clear advantages. The graph structure integrates the molecular images and text descriptions of drugs in a unified framework. This unified perspective helps improve prediction accuracy because it captures both chemical structure and pharmacological background information. In addition, the knowledge graph creates a transparent path from input to output, making it easier to understand why the model makes a specific prediction.

This article will explain how KG4MM is used in practice to predict pairwise drug interactions. It will walk through the steps of building a knowledge graph and integrating molecular and textual information. Through concrete examples, it will illustrate how multimodal learning guided by knowledge graphs can solve practical challenges in medicine and healthcare research. The goal is to show how KG4MM can improve prediction accuracy and interpretability in real-world drug interaction tasks.

Methodology

The KG4MM approach places a knowledge graph at the heart of the entire process. The graph guides how each data type is processed and understood. In the drug interaction example, each drug node in the graph is associated with two pieces of information. The first is a molecular image derived from its SMILES formula, and the second is a textual description containing its class, functional groups, and other key details.

KG4MM is unique in that it leverages graph neural networks (GNNs) to connect graph structures and multimodal data. Based on the placement of a drug in the graph, the GNN determines which parts of its image and which words in its description are most worthy of attention. The edges in the graph—showing how the drug is related to proteins, diseases, and other drugs—help the network determine which visual and textual features are most important. In this way, the knowledge graph does more than just provide additional context; it actively guides the model to focus on the most informative data elements.

The strength of KG4MM lies in its ability to combine pattern recognition neural networks with explicit relationship graphs. GNNs have significant advantages in processing contextual data, so the model can be built on existing knowledge of drug interactions and biochemical properties. This guided learning not only improves prediction accuracy, but also produces clear and interpretable results by highlighting the specific graph connections that affect each prediction.

Implementation Overview

The system is built around a core knowledge graph that integrates all components. This graph captures directed relationships between drugs, proteins, and diseases, such as a drug "binds to" a protein, "inhibits" a target, or "treats" a disease. By placing the graph at the core of the design, every step of the process relies on its structured medical knowledge graph.

To prepare the data, the system associates two representations to each drug node: a molecular image and a textual description. The first is a molecular image generated from its SMILES molecular formula using RDKit. The second is a textual description that summarizes the drug class, functional groups, and other relevant details. Both the image and the text are directly connected to the corresponding drug node in the graph, ensuring that the visual and linguistic features are consistent with the underlying knowledge structure.

Modeling the graph itself relies on graph convolutional networks (GCNs). These networks learn from the position of each node and its connections in the graph, creating embeddings that encode how drugs, proteins, and diseases are related to each other. Meanwhile, multimodal encoders convert images and text into feature vectors: a ResNet processes molecular images, while a BERT model converts text descriptions.

Finally, a graph attention network (GAT) fuses the graph embedding with visual and textual features. The attention mechanism exploits the graph structure to weight the most important features from each modality. The combined representation is then fed into a prediction module that determines whether two drugs will interact. At the same time, the attention weights reveal which graph connections, image regions, or text elements contribute most to the model's decision, providing a clear explanation for each prediction.

Detailed implementation

The implementation begins by installing and importing the required libraries, ensuring that all necessary deep learning, graph processing, and cheminformatics packages are available in the environment. It installs PyTorch and torchvision for neural networks, Hugging Face Transformers for text encoding, NetworkX and torch-geometric for graph manipulation, RDKit and Open Babel for working with molecular structures, and supporting libraries such as pandas, NumPy, and Matplotlib. Once installed, the required libraries and modules are imported for use in the subsequent steps.

# install necessary packages
!pip install torch torchvision transformers networkx spacy rdflib rdkit pillow scikit-learn matplotlib seaborn torch-geometric
# pip alone did not work for openbabel; install the system package first
!apt-get install openbabel
!pip install openbabel-wheel

# import libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from transformers import BertModel, BertTokenizer
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json
import os
from rdkit import Chem
from rdkit.Chem import Draw
from PIL import Image
import io
import base64
from openbabel import openbabel
from torch_geometric.data import Data
import torch_geometric.nn as geom_nn

Data preparation

First, a directory is created to store drug images. Then a simplified DrugBank sample is downloaded from a public repository and saved as a TSV file. This file is loaded into a pandas DataFrame to generate a table containing each drug's unique identifier, name, InChI string for molecular structure, and descriptive metadata such as category and group. This structured dataset lays the foundation for generating visual and textual representations in subsequent steps.

# create directory for data storage
!mkdir -p data/drug_images

# download DrugBank sample data (simplified version for demonstration)
!wget -q -O data/drugbank_sample.tsv https://raw.githubusercontent.com/dhimmel/drugbank/gh-pages/data/drugbank-slim.tsv

# load DrugBank data
drug_df = pd.read_csv('data/drugbank_sample.tsv', sep='\t')

Molecular structure conversion

The molecular structures are provided in InChI format. These representations need to be converted to SMILES format using OpenBabel. SMILES stands for Simplified Molecular Input Line Entry System and provides a concise, text-based way to describe chemical structures. SMILES strings are compatible with tools such as RDKit, which can generate molecular images from SMILES strings. The following code shows how to do this conversion.

# create a SMILES column by converting InChI to SMILES
def inchi_to_smiles_openbabel(inchi_str):
    try:
        # create Open Babel OBMol object from InChI
        obConversion = openbabel.OBConversion()
        obConversion.SetInAndOutFormats("inchi", "smiles")
        mol = openbabel.OBMol()
        # convert InChI to molecule; strip extra newlines or spaces from the output
        if obConversion.ReadString(mol, inchi_str):
            return obConversion.WriteString(mol).strip()
        else:
            return None
    except Exception as e:
        print(f"Error converting InChI to SMILES: {inchi_str}. Error: {e}")
        return None

# apply the conversion to each InChI in the dataframe
drug_df['smiles'] = drug_df['inchi'].apply(inchi_to_smiles_openbabel)
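
As a quick sanity check (not part of the original pipeline), the converter can be tried on a molecule whose SMILES is well known, such as ethanol; the exact output string may vary slightly between Open Babel versions.

# optional sanity check on a known molecule (ethanol)
ethanol_inchi = "InChI=1S/C2H6O/c1-2-3/h3H,1-2H3"
print(inchi_to_smiles_openbabel(ethanol_inchi))  # expected: a SMILES equivalent to "CCO"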

Knowledge graph construction

The system builds a directed medical knowledge graph to capture the relationships between drugs, proteins, and diseases. Each node represents a drug, protein, or disease, and each edge encodes an interaction, such as binds_to, inhibits, or treats. These connections store expert knowledge about how drugs affect biological targets and diseases.

The graph serves as a source of structured relational information that the model leverages alongside image and text features. By explicitly representing domain knowledge, the graph enhances predictive accuracy and the ability to explain why two drugs might interact.

# initialize a medical knowledge graph
medical_kg = nx.DiGraph()

# extract drug entities from DrugBank (limit to 50 drugs for demo)
drug_entities = drug_df['name'].dropna().unique().tolist()[:50]

# create drug nodes
for drug in drug_entities:
    medical_kg.add_node(drug, type='drug')

# add biomedical entities (proteins, targets, diseases)
protein_entities = ["Cytochrome P450", "Albumin", "P-glycoprotein", "GABA Receptor",
                    "Serotonin Receptor", "Beta-Adrenergic Receptor", "ACE", "HMGCR"]
disease_entities = ["Hypertension", "Diabetes", "Depression", "Epilepsy",
                    "Asthma", "Rheumatoid Arthritis", "Parkinson's Disease"]

for protein in protein_entities:
    medical_kg.add_node(protein, type='protein')
for disease in disease_entities:
    medical_kg.add_node(disease, type='disease')

# add relationships (based on common drug mechanisms and interactions)

# drug-protein relationships
drug_protein_relations = [
    ("Warfarin", "binds_to", "Albumin"),
    ("Atorvastatin", "inhibits", "HMGCR"),
    ("Diazepam", "modulates", "GABA Receptor"),
    ("Fluoxetine", "inhibits", "Serotonin Receptor"),
    ("Phenytoin", "induces", "Cytochrome P450"),
    ("Metoprolol", "blocks", "Beta-Adrenergic Receptor"),
    ("Lisinopril", "inhibits", "ACE"),
    ("Rifampin", "induces", "P-glycoprotein"),
    ("Carbamazepine", "induces", "Cytochrome P450"),
    ("Verapamil", "inhibits", "P-glycoprotein")
]

# drug-disease relationships
drug_disease_relations = [
    ("Lisinopril", "treats", "Hypertension"),
    ("Metformin", "treats", "Diabetes"),
    ("Fluoxetine", "treats", "Depression"),
    ("Phenytoin", "treats", "Epilepsy"),
    ("Albuterol", "treats", "Asthma"),
    ("Methotrexate", "treats", "Rheumatoid Arthritis"),
    ("Levodopa", "treats", "Parkinson's Disease")
]

# known drug-drug interactions (based on actual medical knowledge)
drug_drug_interactions = [
    ("Goserelin", "interacts_with", "Desmopressin", "increases_anticoagulant_effect"),
    ("Goserelin", "interacts_with", "Cetrorelix", "increases_bleeding_risk"),
    ("Cyclosporine", "interacts_with", "Felypressin", "decreases_efficacy"),
    ("Octreotide", "interacts_with", "Cyanocobalamin", "increases_hypoglycemia_risk"),
    ("Tetrahydrofolic acid", "interacts_with", "L-Histidine", "increases_statin_concentration"),
    ("S-Adenosylmethionine", "interacts_with", "Pyruvic acid", "decreases_efficacy"),
    ("L-Phenylalanine", "interacts_with", "Biotin", "increases_sedation"),
    ("Choline", "interacts_with", "L-Lysine", "decreases_efficacy")
]

# add all relationships to the knowledge graph
for s, r, o in drug_protein_relations:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r)
for s, r, o in drug_disease_relations:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r)
for s, r, o, mechanism in drug_drug_interactions:
    if s in medical_kg and o in medical_kg:
        medical_kg.add_edge(s, o, relation=r, mechanism=mechanism)
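
An optional check, not in the original code, confirms that nodes and edges were added as expected; the exact counts depend on which of the 50 sampled drugs appear in the relation lists.

# optional: inspect the constructed knowledge graph
print(f"Nodes: {medical_kg.number_of_nodes()}, Edges: {medical_kg.number_of_edges()}")
print(f"Drug nodes: {sum(1 for n, d in medical_kg.nodes(data=True) if d.get('type') == 'drug')}")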

Multimodal data processing

Each drug is represented by three complementary data types. First, its SMILES representation is converted into a molecule object and rendered into an image using RDKit.

# function to generate molecular structure images using RDKit
def generate_molecule_image(smiles_string, size=(224, 224)):
    try:
        mol = Chem.MolFromSmiles(smiles_string)
        if mol:
            img = Draw.MolToImage(mol, size=size)
            return img
        else:
            return None
    except:
        return None
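
As an illustrative aside (not part of the original pipeline), the function can be exercised on a well-known SMILES string such as aspirin to confirm that image generation works:

# optional: render a known molecule (aspirin) to verify image generation
aspirin_img = generate_molecule_image("CC(=O)OC1=CC=CC=C1C(=O)O")
if aspirin_img:
    aspirin_img.save("data/drug_images/aspirin_example.png")  # a 224x224 PIL image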

Second, a descriptive text is constructed by combining the drug name, class, group information, and any other available metadata.

# function to create text description for drugs combining various information
def create_drug_description(row):
    description = f"Drug name: {row['name']}. "
    if pd.notna(row.get('category')):
        description += f"Category: {row['category']}. "
    if pd.notna(row.get('groups')):
        description += f"Groups: {row['groups']}. "
    if pd.notna(row.get('description')):
        description += f"Description: {row['description']}"
    return description
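
A minimal, hypothetical example of the text this produces (the column names mirror the DataFrame fields used above; the values are invented):

# illustrative only: a hypothetical row and the description it would produce
example_row = pd.Series({'name': 'Atorvastatin', 'category': 'Statins', 'groups': 'approved', 'description': None})
print(create_drug_description(example_row))
# -> "Drug name: Atorvastatin. Category: Statins. Groups: approved. "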

Third, the knowledge graph itself is embedded. Each node and relation starts as a random vector and is iteratively adjusted so that, for each real connection, the vector of an entity plus the vector of its relation lands close to the vector of the entity it connects to. After many iterations this forms an embedding space in which connected elements naturally cluster and the direction of a relationship is encoded by its relation vector. The result is a pair of lookup tables that map each node and relation to compact, trainable coordinates reflecting the complete structure of the knowledge graph.
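
The translation-style objective described above is not spelled out in the code below, which focuses on converting the graph into PyTorch Geometric format, so the following is a minimal, hypothetical sketch of the idea, assuming nn.Embedding lookup tables for nodes and relations and a margin-based loss:

# minimal TransE-style sketch (illustrative, not part of the original pipeline):
# an entity vector plus a relation vector should land close to the connected entity's vector
class TransESketch(nn.Module):
    def __init__(self, num_nodes, num_relations, dim=128):
        super().__init__()
        self.node_emb = nn.Embedding(num_nodes, dim)      # lookup table for nodes
        self.rel_emb = nn.Embedding(num_relations, dim)   # lookup table for relations

    def score(self, head, rel, tail):
        # smaller distance means a more plausible (head, relation, tail) triple
        return torch.norm(self.node_emb(head) + self.rel_emb(rel) - self.node_emb(tail), dim=-1)

def transe_step(model, heads, rels, tails, num_nodes, optimizer, margin=1.0):
    # one training step: push real triples closer together than randomly corrupted ones
    neg_tails = torch.randint(0, num_nodes, tails.shape)
    loss = torch.relu(margin + model.score(heads, rels, tails)
                      - model.score(heads, rels, neg_tails)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()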

# convert NetworkX graph to PyG graph for modern graph neural network processing
def convert_nx_to_pyg(nx_graph):
    # create node mappings
    node_to_idx = {node: i for i, node in enumerate(nx_graph.nodes())}

    # create edge lists
    src_nodes = []
    dst_nodes = []
    edge_types = []
    edge_type_to_idx = {}
    for u, v, data in nx_graph.edges(data=True):
        relation = data.get('relation', 'unknown')
        if relation not in edge_type_to_idx:
            edge_type_to_idx[relation] = len(edge_type_to_idx)
        src_nodes.append(node_to_idx[u])
        dst_nodes.append(node_to_idx[v])
        edge_types.append(edge_type_to_idx[relation])

    # create PyG graph
    edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
    edge_type = torch.tensor(edge_types, dtype=torch.long)

    # create node features
    node_types = []
    for node in nx_graph.nodes():
        node_type = nx_graph.nodes[node].get('type', 'unknown')
        node_types.append(node_type)

    # one-hot encode node types
    unique_node_types = sorted(set(node_types))
    node_type_to_idx = {nt: i for i, nt in enumerate(unique_node_types)}
    node_type_features = torch.zeros(len(node_types), len(unique_node_types))
    for i, nt in enumerate(node_types):
        node_type_features[i, node_type_to_idx[nt]] = 1.0

    # create PyG Data object with the proper attributes
    g = Data(
        edge_index=edge_index,
        edge_type=edge_type,
        x=node_type_features  # node features in PyG are stored in 'x'
    )

    # create reverse mappings for later use
    idx_to_node = {idx: node for node, idx in node_to_idx.items()}
    idx_to_edge_type = {idx: edge_type for edge_type, idx in edge_type_to_idx.items()}
    return g, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type

# convert medical_kg to a PyG graph
pyg_graph, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type = convert_nx_to_pyg(medical_kg)
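
An optional inspection, not in the original, confirms that the converted graph has the expected shapes:

# optional: verify the PyG conversion
print(pyg_graph)  # Data(x=[num_nodes, num_node_types], edge_index=[2, num_edges], edge_type=[num_edges])
print(pyg_graph.x.shape, pyg_graph.edge_index.shape)
print(edge_type_to_idx)  # e.g. {'binds_to': 0, 'inhibits': 1, ...}; order depends on edge iteration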

These visual, textual, and structured representations are saved so that the model can fuse them for interaction prediction.

# process drug data to create multi-modal representations
drug_data = []
for idx, row in drug_df.iterrows():
    if row['name'] in drug_entities and pd.notna(row.get('smiles')):
        # generate molecule image
        img = generate_molecule_image(row['smiles'])
        if img:
            img_path = f"data/drug_images/{row['drugbank_id']}.png"
            img.save(img_path)

            # create text description
            description = create_drug_description(row)

            # store drug information
            drug_data.append({
                'id': row['drugbank_id'],
                'name': row['name'],
                'smiles': row['smiles'],
                'description': description,
                'image_path': img_path
            })

drug_data_df = pd.DataFrame(drug_data)

Encoder Development

MultimodalNodeEncoder creates a single encoder that converts each node's molecule image and its textual summary into compatible feature vectors. First, it applies a deep convolutional network to the molecular image to distill it into a compact visual fingerprint. At the same time, it processes the drug description through a pre-trained language model to extract a semantic summary. The outputs of both are then projected into the same vector space so that visual and textual signals can be meaningfully combined under the guidance of the knowledge graph structure.

# processes visual and textual features for nodes
class MultimodalNodeEncoder(nn.Module):
    def __init__(self, output_dim=128):
        super(MultimodalNodeEncoder, self).__init__()
        # image encoder (ResNet)
        resnet = models.resnet18(pretrained=True)
        # remove the final fully connected layer to get 512 features
        self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.image_projection = nn.Linear(512, output_dim)

        # text encoder (BERT)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        # BERT base outputs 768 features
        self.text_projection = nn.Linear(768, output_dim)

    def forward(self, image, text):
        # image encoding
        img_features = self.image_encoder(image).squeeze(-1).squeeze(-1)
        img_features = self.image_projection(img_features)

        # text encoding
        encoded_input = self.tokenizer(text, padding=True, truncation=True,
                                       return_tensors="pt", max_length=128)
        # move encoded input to the same device as the image
        input_ids = encoded_input['input_ids'].to(image.device)
        attention_mask = encoded_input['attention_mask'].to(image.device)
        text_outputs = self.text_encoder(input_ids=input_ids,
                                         attention_mask=attention_mask)
        # use the [CLS] token embedding (first token)
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_features = self.text_projection(text_features)

        return img_features, text_features
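
A small smoke test, not part of the original, checks that both modalities project to the same dimension; it downloads the pretrained ResNet-18 and BERT weights on first use.

# optional: check that both modalities project to the same dimension
encoder = MultimodalNodeEncoder(output_dim=128)
dummy_image = torch.randn(1, 3, 224, 224)  # one RGB molecule image
img_feat, text_feat = encoder(dummy_image, ["Drug name: Aspirin. Category: NSAID."])
print(img_feat.shape, text_feat.shape)  # torch.Size([1, 128]) torch.Size([1, 128])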

Model Integration

The KG-guided multimodal model fuses the visual, textual, and type embeddings of each node and predicts drug-drug interactions under the guidance of the knowledge graph. It first projects the image and textual outputs of each node into a shared space and assigns a separate embedding to its node type. These embeddings are then propagated across the graph, allowing each node to fuse its own features with the signals collected from its neighbors. The attention step reweights these fused features based on the strength and type of the connection. When evaluating a pair of drugs, the model takes their refined node embeddings, combines them through concatenation, element-wise products, and differences, and then feeds the results into the prediction head to produce the probability of interaction. By letting the topology of the graph determine how multimodal signals are fused, the predictions generated by the model are both accurate and directly traceable to the underlying network structure.

# define KG-guided multimodal model
class KGGuidedMultimodalModel(nn.Module):
    def __init__(self, pyg_graph, num_node_types, num_edge_types, node_to_idx, idx_to_node, hidden_dim=128):
        super(KGGuidedMultimodalModel, self).__init__()
        self.pyg_graph = pyg_graph
        self.node_to_idx = node_to_idx
        self.idx_to_node = idx_to_node
        self.hidden_dim = hidden_dim

        # multimodal encoder for processing node-associated data
        self.multimodal_encoder = MultimodalNodeEncoder(output_dim=hidden_dim)

        # node type embeddings
        self.node_type_embedding = nn.Embedding(num_node_types, hidden_dim)

        # graph neural network layers for knowledge graph processing (PyG GCNConv instead of dglnn.GraphConv)
        self.gnn_layers = nn.ModuleList([
            geom_nn.GCNConv(hidden_dim, hidden_dim),
            geom_nn.GCNConv(hidden_dim, hidden_dim),
        ])

        # graph attention network for integrating multimodal features with graph structure (PyG GATConv)
        # set the per-head output dimension so the total output is hidden_dim (not hidden_dim * num_heads)
        self.gat_layer = geom_nn.GATConv(hidden_dim, hidden_dim // 4, heads=4)

        # relation prediction layer - sized to match the concatenated pair representation
        self.relation_prediction = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1)
        )

    def get_node_representation(self, node_name, image=None, text=None):
        if node_name not in self.node_to_idx:
            # handle unknown nodes
            return torch.zeros(self.hidden_dim, device=self.pyg_graph.edge_index.device)

        node_idx = self.node_to_idx[node_name]
        # get node type features - stored in x instead of ndata['type']
        node_type_feat = self.pyg_graph.x[node_idx]
        node_type_embedding = self.node_type_embedding(torch.argmax(node_type_feat))

        # if multimodal data is provided, process it
        if image is not None and text is not None:
            img_feat, text_feat = self.multimodal_encoder(image, text)
            # squeeze out the batch dimension to match shapes
            img_feat = img_feat.squeeze(0)
            text_feat = text_feat.squeeze(0)

            # knowledge graph structure guides how multimodal features are integrated:
            # use node_type_embedding as a query to attend to multimodal features
            attention_weights = torch.softmax(
                torch.matmul(
                    torch.stack([img_feat, text_feat, node_type_embedding]),
                    node_type_embedding
                ),
                dim=0
            )

            # weighted combination of features
            combined_feat = (
                attention_weights[0] * img_feat +
                attention_weights[1] * text_feat +
                attention_weights[2] * node_type_embedding
            )
            return combined_feat
        else:
            # for nodes without multimodal data, just use the type embedding
            return node_type_embedding

    def forward(self, drug1_image, drug1_text, drug1_name, drug2_image, drug2_text, drug2_name):
        # process the entire graph first
        device = self.pyg_graph.edge_index.device
        x = torch.zeros((self.pyg_graph.x.size(0), self.hidden_dim), device=device)

        # initialize known node features
        for i, node_name in enumerate([drug1_name, drug2_name]):
            if node_name in self.node_to_idx:
                node_idx = self.node_to_idx[node_name]
                if i == 0:
                    x[node_idx] = self.get_node_representation(node_name, drug1_image, drug1_text)
                else:
                    x[node_idx] = self.get_node_representation(node_name, drug2_image, drug2_text)

        # apply graph convolutions to propagate information - PyG style
        edge_index = self.pyg_graph.edge_index
        for layer in self.gnn_layers:
            x = layer(x, edge_index)
            x = torch.relu(x)

        # apply graph attention to integrate features - PyG style
        x = self.gat_layer(x, edge_index)

        # get final representations for the two drugs
        drug1_idx = self.node_to_idx.get(drug1_name, 0)
        drug2_idx = self.node_to_idx.get(drug2_name, 0)
        drug1_repr = x[drug1_idx]
        drug2_repr = x[drug2_idx]

        # predict interaction: concatenate representations in multiple ways to capture the relationship
        concat_repr = torch.cat([
            drug1_repr,
            drug2_repr,
            drug1_repr * drug2_repr,
            torch.abs(drug1_repr - drug2_repr)
        ], dim=0)

        interaction_prob = torch.sigmoid(self.relation_prediction(concat_repr.unsqueeze(0)).squeeze())
        return interaction_prob

Knowledge Extraction

The algorithm constructs a focused subgraph based on how the two drugs are related in the larger graph. It first looks for any direct edges between the two drugs in the graph and records their properties if found. Next, it identifies proteins or diseases associated with both drugs, revealing shared mechanisms. Finally, it traces all simple paths of no more than a given length to reveal indirect connections through intermediate nodes. The result is a compact network of key nodes and edges that captures the domain knowledge behind the predicted interactions and guides downstream layers to emphasize the most relevant multimodal features.

# function to retrieve knowledge subgraph relevant to a drug pair
def retrieve_knowledge_subgraph(graph, drug1, drug2, max_path_length=3):
    relevant_knowledge = {
        'direct_interaction': None,
        'common_targets': [],
        'paths': []
    }

    # check for direct interaction
    if graph.has_edge(drug1, drug2):
        edge_data = graph.get_edge_data(drug1, drug2)
        relevant_knowledge['direct_interaction'] = edge_data

    # find common targets (proteins, diseases)
    drug1_neighbors = set(graph.neighbors(drug1)) if drug1 in graph else set()
    drug2_neighbors = set(graph.neighbors(drug2)) if drug2 in graph else set()
    common_neighbors = drug1_neighbors.intersection(drug2_neighbors)
    for common_node in common_neighbors:
        node_type = graph.nodes[common_node].get('type', '')
        if node_type == 'protein' or node_type == 'disease':
            relevant_knowledge['common_targets'].append(common_node)

    # find paths between drugs (up to max_path_length)
    try:
        paths = list(nx.all_simple_paths(graph, drug1, drug2, cutoff=max_path_length))
        relevant_knowledge['paths'] = paths
    except (nx.NetworkXError, nx.NodeNotFound):
        # handle cases where paths do not exist or nodes are not in the graph
        pass

    return relevant_knowledge
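
For the drug pair examined later in this article, a lookup like the following would be expected, assuming both drugs made it into the sampled graph; the exact dictionary contents depend on the loaded DrugBank sample.

# illustrative: what the lookup might return for the pair used in the Results section
example_knowledge = retrieve_knowledge_subgraph(medical_kg, "Goserelin", "Desmopressin")
print(example_knowledge['direct_interaction'])  # e.g. {'relation': 'interacts_with', 'mechanism': 'increases_anticoagulant_effect'}
print(example_knowledge['common_targets'])      # [] - no shared protein or disease nodes for this pair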

Batch data processing

The collate function prepares each training batch, first discarding any samples that failed to load or were incomplete. It then stacks the valid molecule images for each drug in the pair into batched tensors, while collecting the corresponding text summaries and names into parallel lists. Interaction labels are likewise combined into a single tensor. By returning a unified dictionary of these batched components, or empty placeholders when no valid samples remain, it ensures that the model always receives well-structured, homogeneous input despite occasional missing data.

# custom collate function to handle None values
def custom_collate_fn(batch):
    # filter out None values
    batch = [item for item in batch if item is not None]

    # return empty batch if all items were None
    if len(batch) == 0:
        return {
            'drug1_img': torch.tensor([]),
            'drug1_text': [],
            'drug1_name': [],
            'drug2_img': torch.tensor([]),
            'drug2_text': [],
            'drug2_name': [],
            'label': torch.tensor([])
        }

    # process non-None items
    drug1_imgs = torch.stack([item['drug1_img'] for item in batch])
    drug1_texts = [item['drug1_text'] for item in batch]
    drug1_names = [item['drug1_name'] for item in batch]
    drug2_imgs = torch.stack([item['drug2_img'] for item in batch])
    drug2_texts = [item['drug2_text'] for item in batch]
    drug2_names = [item['drug2_name'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])

    return {
        'drug1_img': drug1_imgs,
        'drug1_text': drug1_texts,
        'drug1_name': drug1_names,
        'drug2_img': drug2_imgs,
        'drug2_text': drug2_texts,
        'drug2_name': drug2_names,
        'label': labels
    }

Dataset preparation

The training examples combine all true interactions and a set of matched random non-interacting pairs. The process first extracts all known drug interaction pairs, then samples an equal number of negative pairs to balance the dataset. When an example is obtained, the molecular images and text summaries for each drug are loaded and preprocessed, skipping any pairs with missing data, and generating records containing images, descriptions, names, and a binary label for both drugs. By pairing positive and negative examples, applying consistent image transformations, and robustly handling missing data, the dataset provides reliable, ready-to-use batches for training interaction prediction models.

# define dataset for DDI prediction
class DDIDataset(Dataset):
    def __init__(self, drug_data_df, drug_drug_interactions, medical_kg, node_to_idx, transform=None):
        self.drug_data = drug_data_df
        self.drug_name_to_idx = {row['name']: i for i, row in drug_data_df.iterrows()}
        self.node_to_idx = node_to_idx
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # create pairs of drugs with interaction labels
        self.pairs = []
        drug_names = list(self.drug_name_to_idx.keys())

        # positive samples (known interactions)
        for interaction in drug_drug_interactions:
            drug1, _, drug2, _ = interaction
            if drug1 in drug_names and drug2 in drug_names:
                # 1 for positive interaction
                self.pairs.append((drug1, drug2, 1))
        positive_pairs = set((d1, d2) for d1, d2, _ in self.pairs)

        # generate some negative samples
        np.random.seed(42)
        neg_count = 0
        max_neg = len(self.pairs)
        while neg_count < max_neg:
            i, j = np.random.choice(len(drug_names), 2, replace=False)
            drug1, drug2 = drug_names[i], drug_names[j]
            if (drug1, drug2) not in positive_pairs and (drug2, drug1) not in positive_pairs:
                # 0 for negative interaction
                self.pairs.append((drug1, drug2, 0))
                neg_count += 1

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        try:
            drug1_name, drug2_name, label = self.pairs[idx]

            # get drug1 data
            drug1_idx = self.drug_name_to_idx[drug1_name]
            drug1_data = self.drug_data.iloc[drug1_idx]

            # load drug1 image with error handling
            try:
                drug1_img = Image.open(drug1_data['image_path']).convert('RGB')
                drug1_img = self.transform(drug1_img)
            except Exception as e:
                print(f"Error loading drug1 image for {drug1_name}: {str(e)}")
                return None
            drug1_text = drug1_data['description']

            # get drug2 data
            drug2_idx = self.drug_name_to_idx[drug2_name]
            drug2_data = self.drug_data.iloc[drug2_idx]

            # load drug2 image with error handling
            try:
                drug2_img = Image.open(drug2_data['image_path']).convert('RGB')
                drug2_img = self.transform(drug2_img)
            except Exception as e:
                print(f"Error loading drug2 image for {drug2_name}: {str(e)}")
                return None
            drug2_text = drug2_data['description']

            return {
                'drug1_img': drug1_img,
                'drug1_text': drug1_text,
                'drug1_name': drug1_name,
                'drug2_img': drug2_img,
                'drug2_text': drug2_text,
                'drug2_name': drug2_name,
                'label': torch.tensor(label, dtype=torch.float32)
            }
        except Exception as e:
            print(f"Error in __getitem__ for index {idx}: {str(e)}")
            return None

Model Training

At the start of training, the model and its graph are moved to the selected device (GPU or CPU), and training runs for a set number of epochs, each split into a training phase and a validation phase. During training, batches of drug pairs are passed through the network to produce interaction scores, a binary cross-entropy loss is computed, and the Adam optimizer updates all parameters via backpropagation. Losses and correct-prediction counts are aggregated to report the average training loss and accuracy at the end of each epoch, with empty or malformed batches skipped for stability. The process then switches to evaluation mode, running the same steps on the validation batches without gradient updates to measure validation loss and accuracy.

# training function
def train_kg4mm_model(model, train_loader, val_loader, epochs=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.pyg_graph = model.pyg_graph.to(device)

    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    for epoch in range(epochs):
        # training phase
        model.train()
        train_loss = 0
        train_correct = 0
        batch_count = 0

        for batch in train_loader:
            # skip empty batches
            if len(batch['drug1_img']) == 0:
                print("Skipping empty batch")
                continue

            batch_count += 1
            try:
                drug1_img = batch['drug1_img'].to(device)
                drug1_text = batch['drug1_text']
                drug1_name = batch['drug1_name']
                drug2_img = batch['drug2_img'].to(device)
                drug2_text = batch['drug2_text']
                drug2_name = batch['drug2_name']
                labels = batch['label'].to(device)

                # forward pass - processing one pair at a time for clarity
                batch_size = len(drug1_name)
                outputs = torch.zeros(batch_size, 1, device=device)
                for i in range(batch_size):
                    # this loop is for illustration - in practice, handle batch processing more efficiently
                    output = model(
                        drug1_img[i].unsqueeze(0),
                        [drug1_text[i]],
                        drug1_name[i],
                        drug2_img[i].unsqueeze(0),
                        [drug2_text[i]],
                        drug2_name[i]
                    )
                    outputs[i] = output

                # calculate loss
                loss = criterion(outputs, labels.unsqueeze(1))

                # backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                # calculate accuracy
                predictions = (outputs >= 0.5).float()
                train_correct += (predictions == labels.unsqueeze(1)).sum().item()

                print(f"Batch {batch_count}: Loss: {loss.item():.4f}")
            except Exception as e:
                print(f"Error processing batch {batch_count}: {str(e)}")
                import traceback
                traceback.print_exc()
                continue

        avg_train_loss = train_loss / max(1, batch_count)
        train_acc = train_correct / max(1, batch_count * batch['drug1_img'].size(0))
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}')

        # validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_batch_count = 0
        with torch.no_grad():
            for batch in val_loader:
                # skip empty batches
                if len(batch['drug1_img']) == 0:
                    continue

                val_batch_count += 1
                try:
                    drug1_img = batch['drug1_img'].to(device)
                    drug1_text = batch['drug1_text']
                    drug1_name = batch['drug1_name']
                    drug2_img = batch['drug2_img'].to(device)
                    drug2_text = batch['drug2_text']
                    drug2_name = batch['drug2_name']
                    labels = batch['label'].to(device)

                    # forward pass - processing one pair at a time for clarity
                    batch_size = len(drug1_name)
                    outputs = torch.zeros(batch_size, 1, device=device)
                    for i in range(batch_size):
                        output = model(
                            drug1_img[i].unsqueeze(0),
                            [drug1_text[i]],
                            drug1_name[i],
                            drug2_img[i].unsqueeze(0),
                            [drug2_text[i]],
                            drug2_name[i]
                        )
                        outputs[i] = output

                    # calculate loss
                    loss = criterion(outputs, labels.unsqueeze(1))
                    val_loss += loss.item()

                    # calculate accuracy
                    predictions = (outputs >= 0.5).float()
                    val_correct += (predictions == labels.unsqueeze(1)).sum().item()
                except Exception as e:
                    print(f"Error processing validation batch {val_batch_count}: {str(e)}")
                    continue

        avg_val_loss = val_loss / max(1, val_batch_count)
        val_acc = val_correct / max(1, val_batch_count * 4)  # assuming batch_size=4
        print(f'Epoch {epoch+1}/{epochs}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}')

    return model

Prepare data and start training

The prepared dataset then generates pairs of drug examples — each with its molecule image and text summary — and splits them into training and validation sets for performance tracking. The data loader packages these multimodal examples (images, descriptions, and labels) into batches so that they can be smoothly fed into the model. The KG-guided prediction network is instantiated with dimensions derived from the graph’s node and edge types, ensuring that its layers align with the structure of the knowledge graph. Finally, the training loop runs for a fixed number of rounds, alternating between updating the model on the training data and measuring its accuracy on the validation set. This sequence completes the transition from data preparation to active, graph-driven learning.

# initialize dataset and model
ddi_dataset = DDIDataset(drug_data_df, drug_drug_interactions, medical_kg, node_to_idx)

# split dataset into train and validation sets
train_size = int(0.8 * len(ddi_dataset))
val_size = len(ddi_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(ddi_dataset, [train_size, val_size])

# create data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=custom_collate_fn)

# derive model dimensions from the PyG graph
num_node_types = pyg_graph.x.shape[1]
num_edge_types = len(edge_type_to_idx)

# initialize the KG-guided multimodal model
model = KGGuidedMultimodalModel(pyg_graph, num_node_types, num_edge_types, node_to_idx, idx_to_node)

# train the model
trained_model = train_kg4mm_model(model, train_loader, val_loader, epochs=5)

Reasoning and explanation

To make a prediction, the model first loads the processed images and text summaries for each drug and determines where each drug is located in the knowledge graph. It then produces a probability score showing how the visual, textual, and graph information work together. At the same time, the system checks the graph for any direct links between the two drugs, any proteins or diseases they are both connected to, and any simple paths connecting them that do not exceed a given length. This probability is converted into a low, medium, or high risk level. An explanation is then constructed that highlights known interaction mechanisms, shared targets, and key graph paths that guide decision making. Finally, the system provides exemplary clinical recommendations based on risk level, clearly demonstrating how the knowledge graph shapes predictions and their interpretations.

def predict_interaction(model, drug1_name, drug2_name, drug_data_df, medical_kg):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # get drug indices
    drug1_idx = drug_data_df[drug_data_df['name'] == drug1_name].index[0]
    drug2_idx = drug_data_df[drug_data_df['name'] == drug2_name].index[0]

    # get drug data
    drug1_data = drug_data_df.iloc[drug1_idx]
    drug2_data = drug_data_df.iloc[drug2_idx]

    # prepare images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    drug1_img = Image.open(drug1_data['image_path']).convert('RGB')
    drug1_img = transform(drug1_img).unsqueeze(0).to(device)
    drug1_text = [drug1_data['description']]

    drug2_img = Image.open(drug2_data['image_path']).convert('RGB')
    drug2_img = transform(drug2_img).unsqueeze(0).to(device)
    drug2_text = [drug2_data['description']]

    # get knowledge subgraph for the drug pair
    knowledge = retrieve_knowledge_subgraph(medical_kg, drug1_name, drug2_name)

    # make prediction
    with torch.no_grad():
        interaction_prob = model(
            drug1_img,
            drug1_text,
            drug1_name,
            drug2_img,
            drug2_text,
            drug2_name
        )

    return interaction_prob.item(), knowledge


def explain_interaction_prediction(drug1_name, drug2_name, probability, knowledge):
    explanation = f"KG-guided multimodal analysis for interaction between {drug1_name} and {drug2_name}:\n\n"

    # interpret the probability
    if probability > 0.8:
        risk_level = "High"
    elif probability > 0.5:
        risk_level = "Moderate"
    else:
        risk_level = "Low"
    explanation += f"Interaction Risk Level: {risk_level} (Probability: {probability:.2f})\n\n"

    # explain based on knowledge graph structure
    explanation += "Knowledge Graph Analysis:\n"
    if knowledge['direct_interaction']:
        mechanism = knowledge['direct_interaction'].get('mechanism', 'unknown mechanism')
        explanation += f"✓ Direct Connection: The knowledge graph contains a documented interaction between these drugs with {mechanism}.\n\n"
    if knowledge['common_targets']:
        explanation += "✓ Common Target Nodes: These drugs connect to shared entities in the knowledge graph:\n"
        for target in knowledge['common_targets']:
            explanation += f" - {target}\n"
        explanation += " This graph structure suggests potential interaction through common binding sites or pathways.\n\n"
    if knowledge['paths'] and len(knowledge['paths']) > 0:
        explanation += "✓ Knowledge Graph Pathways: The model identified these connecting paths in the graph:\n"
        for i, path in enumerate(knowledge['paths'][:3]):
            path_str = " → ".join(path)
            explanation += f" - Path {i+1}: {path_str}\n"
        explanation += " These graph structures guided the multimodal feature integration for prediction.\n\n"

    # focus on how KG structure guided the interpretation
    explanation += "Multimodal Integration Process:\n"
    explanation += " - Knowledge graph structure determined which drug properties were most relevant\n"
    explanation += " - Graph neural networks analyzed the local neighborhood of both drug nodes\n"
    explanation += " - Node position in the graph guided the weighting of visual and textual features\n\n"

    # clinical implications (example - in a real system, this would be more comprehensive)
    if probability > 0.5:
        explanation += "Clinical Recommendations (based on graph analysis):\n"
        explanation += " - Consider alternative medications not connected in similar graph patterns\n"
        explanation += " - If co-administration is necessary, monitor for interaction effects\n"
        explanation += " - Review other drugs connected to the same nodes for potential complications\n"
    else:
        explanation += "Clinical Recommendations (based on graph analysis):\n"
        explanation += " - Standard monitoring advised\n"
        explanation += " - The knowledge graph structure suggests minimal interaction concerns\n"

    return explanation

Results

To illustrate the full workflow, two drugs are selected and their pre-generated images and text summaries are loaded and pre-processed as they were during training. These multimodal inputs are then passed through the trained model - now in evaluation mode - to produce a probability score quantifying their risk of interaction. Meanwhile, for visualization and interpretation, the process extracts the relevant part of the knowledge graph by collecting all direct connections, shared biological targets, and any simple paths connecting them that do not exceed a given length, and then augments this subgraph by adding a layer of direct neighbors to obtain broader context.

The extracted subgraph is drawn with a simple color scheme that distinguishes the two drugs, proteins, diseases, and other entities, so the network structure is clear at a glance. A natural-language explanation then ties the probability score to these graph features by highlighting any documented interaction mechanisms, shared targets, and key connecting paths. Together, the risk estimate, the color-coded visualization, and the narrative explanation show how the topology of the knowledge graph guides the fusion of visual and textual signals and provide a transparent rationale for the model's prediction.

# example usage
drug_pair = ("Goserelin", "Desmopressin")
prob, knowledge = predict_interaction(trained_model, drug_pair[0], drug_pair[1], drug_data_df, medical_kg)

print(f"Predicted interaction probability between {drug_pair[0]} and {drug_pair[1]}: {prob:.4f}")
print("\nKnowledge Graph Structure Analysis:")
print(f"Direct connection: {knowledge['direct_interaction']}")
print(f"Common target nodes: {knowledge['common_targets']}")
print(f"Graph paths connecting drugs:")
for path in knowledge['paths']:
    print(f" {' -> '.join(path)}")

# visualize the subgraph for these drugs to show the KG-guided approach
plt.figure(figsize=(12, 8))
subgraph_nodes = set([drug_pair[0], drug_pair[1]])

# add intermediate nodes in paths to highlight the KG structure
for path in knowledge['paths']:
    subgraph_nodes.update(path)

# add a level of neighbors to show context in the KG
neighbors_to_add = set()
for node in subgraph_nodes:
    if node in medical_kg:
        neighbors_to_add.update(list(medical_kg.neighbors(node))[:3])
subgraph_nodes.update(neighbors_to_add)

subgraph = medical_kg.subgraph(subgraph_nodes)

# use different colors for node types to emphasize KG structure
node_colors = []
for node in subgraph.nodes():
    if node == drug_pair[0] or node == drug_pair[1]:
        node_colors.append('lightcoral')
    elif subgraph.nodes[node].get('type') == 'protein':
        node_colors.append('lightblue')
    elif subgraph.nodes[node].get('type') == 'disease':
        node_colors.append('lightgreen')
    else:
        node_colors.append('lightgray')

pos = nx.spring_layout(subgraph, seed=42)
nx.draw(subgraph, pos, with_labels=True, node_color=node_colors,
        node_size=2000, arrows=True, arrowsize=20)
edge_labels = {(s, o): subgraph[s][o]['relation'] for s, o in subgraph.edges()}
nx.draw_networkx_edge_labels(subgraph, pos, edge_labels=edge_labels)
plt.title(f"Knowledge Graph Structure Guiding {drug_pair[0]} and {drug_pair[1]} Interaction Analysis")
plt.savefig('kg_guided_interaction_analysis.png')
plt.show()

# show explanation
explanation = explain_interaction_prediction(drug_pair[0], drug_pair[1], prob, knowledge)
print(explanation)

When tested on Goserelin and Desmopressin, the model returned a probability of 0.54, classifying the pair as moderate risk. The knowledge graph contains a direct "interacts_with" relationship between the two drugs, annotated with the mechanism "increases_anticoagulant_effect", and no shared protein or disease connections, so the explanation centers on that documented mechanism. In the subgraph visualization, the two drugs are highlighted in red and the single directed edge between them stands out, making it clear which relationship drives the prediction.

Conclusion

KG4MM research shows that putting a knowledge graph at the core of the workflow can better fuse molecular images and text than single-source approaches. Each prediction is supported by clear graph evidence—direct edges, shared targets, and connected pathways—that links the results to real biological relationships. By doing so, KG4MM provides stronger predictive power and built-in interpretability in fields such as biochemistry, materials science, and medical diagnostics.