There is a new trick for easing data scarcity in graph learning!
OpenGraph is a graph foundation model designed specifically for zero-shot prediction across a wide variety of graph datasets.
It was developed by the team of Chao Huang, head of the Data Intelligence Laboratory at the University of Hong Kong, who also proposed enhancement and adaptation techniques that improve the model's adaptability to new tasks.
The work has been open-sourced on GitHub.
By introducing data augmentation, the work explores in depth how to strengthen the generalization ability of graph models, especially when training and test data differ significantly.
OpenGraph is a general graph foundation model that makes zero-shot predictions on new data through forward propagation alone.
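To achieve this goal, the team had to overcome three challenges: differences in node and edge sets across datasets, the cost of efficiently modeling dependencies between nodes, and the scarcity of domain-specific training data.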
Using a series of new methods, such as a topology-aware graph Tokenizer and an anchor-based graph Transformer, OpenGraph effectively addresses these challenges. Test results on multiple datasets demonstrate the model's excellent generalization ability.
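The OpenGraph architecture consists of three core components: a unified graph Tokenizer, a scalable graph Transformer, and knowledge distillation from large language models.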
First, the unified graph Tokenizer.
To handle differences in nodes and edges across datasets, the team developed a unified graph Tokenizer that normalizes graph data into token sequences.
This process includes high-order adjacency matrix smoothing and topology-aware mapping.
High-order adjacency matrix smoothing uses higher powers of the adjacency matrix to address sparse connectivity, while topology-aware mapping converts the adjacency matrix into a node sequence, using fast singular value decomposition (SVD) to minimize information loss and retain more of the graph's structure.
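The two Tokenizer steps can be sketched in Python with NumPy. This is a minimal illustration under stated assumptions, not OpenGraph's implementation: the symmetric normalization, summing the powers Ã^0 through Ã^L as the smoothing scheme, and scaling singular vectors by the square root of the singular values are all assumptions, and the full `np.linalg.svd` call stands in for the fast SVD the team uses. The function names are hypothetical. The property it shows is that any graph, whatever its node count, is mapped to tokens of the same width `dim`.

```python
import numpy as np


def normalize_adjacency(adj: np.ndarray) -> np.ndarray:
    """Symmetric normalization D^{-1/2} A D^{-1/2}; isolated nodes get zero rows."""
    deg = adj.sum(axis=1).astype(float)
    inv_sqrt = np.zeros_like(deg)
    nonzero = deg > 0
    inv_sqrt[nonzero] = deg[nonzero] ** -0.5
    return inv_sqrt[:, None] * adj * inv_sqrt[None, :]


def smooth_adjacency(a_norm: np.ndarray, order: int) -> np.ndarray:
    """Assumed smoothing: A^0 + A^1 + ... + A^order (order=0 means no smoothing)."""
    n = a_norm.shape[0]
    power = np.eye(n)
    total = np.eye(n)
    for _ in range(order):
        power = power @ a_norm
        total = total + power
    return total


def topology_aware_tokens(adj: np.ndarray, order: int = 2, dim: int = 8) -> np.ndarray:
    """Map each node of an arbitrary graph to a token of fixed width `dim`."""
    a_norm = normalize_adjacency(adj)
    # Full SVD stands in for a fast (truncated or randomized) SVD on large graphs.
    u, s, _ = np.linalg.svd(a_norm)
    k = min(dim, len(s))
    projection = np.zeros((adj.shape[0], dim))  # zero-padded if the graph has < dim nodes
    projection[:, :k] = u[:, :k] * np.sqrt(s[:k])
    # Smoothing mixes each node's projection with those of its multi-hop neighbours.
    return smooth_adjacency(a_norm, order) @ projection


# A 4-node path graph 0-1-2-3.
path = np.array([[0, 1, 0, 0],
                 [1, 0, 1, 0],
                 [0, 1, 0, 1],
                 [0, 0, 1, 0]], dtype=float)
tokens = topology_aware_tokens(path, order=2, dim=8)
print(tokens.shape)  # (4, 8)
```

With `order=0` the smoothing matrix is the identity, so the tokens are just the SVD projection, which corresponds to skipping smoothing entirely.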
Second, the scalable graph Transformer.
After tokenization, OpenGraph uses a Transformer architecture to model dependencies between nodes, relying mainly on the following techniques to improve performance and efficiency:
First, token sequence sampling: sampling reduces the number of node relations the model must process, which lowers the time and space complexity of training.
Second, anchor-sampling self-attention: by passing information between nodes in stages through a sampled set of anchor nodes, this method further reduces computational complexity and improves the model's training efficiency and stability.
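Both efficiency techniques can be illustrated together in a short NumPy sketch. It is a simplified illustration under assumptions, not the authors' code: the hypothetical `sample_token_sequence` keeps the current batch's nodes and tops the sequence up with uniformly sampled ones, and the hypothetical `anchor_attention` picks anchors at random and omits learned projections, multiple heads and residual connections. With n tokens and m anchors, the two attention stages compute m x n and n x m score matrices, so cost grows as O(nm) rather than the O(n²) of full self-attention. In the example below that is 2 x 256 x 32 = 16,384 scores per step instead of 100,000,000 for full attention over all 10,000 nodes.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend(queries: np.ndarray, keys: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention without learned projections."""
    scores = queries @ keys.T / np.sqrt(queries.shape[-1])
    return softmax(scores, axis=-1) @ values


def sample_token_sequence(num_nodes: int, batch_nodes, seq_len: int,
                          rng: np.random.Generator) -> np.ndarray:
    """Hypothetical: keep the batch's (deduplicated) nodes, top up with random others."""
    batch = np.unique(batch_nodes)[:seq_len]
    others = np.setdiff1d(np.arange(num_nodes), batch)
    extra = rng.choice(others, size=min(seq_len - len(batch), len(others)), replace=False)
    return np.concatenate([batch, extra])


def anchor_attention(tokens: np.ndarray, num_anchors: int,
                     rng: np.random.Generator) -> np.ndarray:
    n = tokens.shape[0]
    idx = rng.choice(n, size=min(num_anchors, n), replace=False)
    anchors = tokens[idx]
    # Stage 1: anchors gather information from all tokens (m x n scores).
    anchors = attend(anchors, tokens, tokens)
    # Stage 2: every token reads back from the updated anchors (n x m scores).
    return attend(tokens, anchors, anchors)


rng = np.random.default_rng(0)
all_tokens = rng.normal(size=(10_000, 16))  # stand-in for Tokenizer output
seq = sample_token_sequence(10_000, [3, 7, 7, 1], seq_len=256, rng=rng)
out = anchor_attention(all_tokens[seq], num_anchors=32, rng=rng)
print(out.shape)  # (256, 16)
```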
Third, knowledge distillation from large language models.
To address the data-privacy and category-diversity issues that arise when training general graph models, the team drew on the knowledge and understanding capabilities of large language models (LLMs), using them to generate diverse graph-structured data.
This data augmentation mechanism improves the quality and practicality of the training data by simulating the characteristics of real-world graphs.
The team first generated a set of nodes tailored to the target application, each with a textual description that is later used to generate edges.
For large-scale node sets, such as those of an e-commerce platform, the researchers subdivide nodes into more specific subcategories.
For example, "electronic products" is refined into "mobile phones", "laptops", and so on, and the process repeats until the nodes are fine-grained enough to resemble real instances.
This prompt tree algorithm subdivides nodes along a tree structure to generate more detailed entities: it starts from a general category such as "product", progressively refines it into specific subcategories, and finally forms a node tree.
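The recursive refinement can be sketched as a depth-first expansion. This is a minimal sketch, assuming a callable that returns subcategories for a category; in a real pipeline that callable would prompt an LLM, but here `toy_llm` and its `TOY_TAXONOMY` (including "clothing" and "shoes") are hypothetical stand-ins, and `expand_prompt_tree` and `max_depth` are illustrative names. Each leaf keeps its full path, which can double as the node's text description.

```python
from typing import Callable, List


def expand_prompt_tree(root: str, propose_subcategories: Callable[[str], List[str]],
                       max_depth: int) -> List[str]:
    """Refine `root` depth-first; each leaf path becomes a node's text description."""
    leaves: List[str] = []

    def visit(path: List[str]) -> None:
        # Stop refining once the path is max_depth levels below the root.
        children = propose_subcategories(path[-1]) if len(path) <= max_depth else []
        if not children:
            leaves.append(" > ".join(path))
            return
        for child in children:
            visit(path + [child])

    visit([root])
    return leaves


# Hypothetical stand-in for an LLM that proposes subcategories.
TOY_TAXONOMY = {
    "product": ["electronic products", "clothing"],
    "electronic products": ["mobile phones", "laptops"],
    "clothing": ["shoes"],
}


def toy_llm(category: str) -> List[str]:
    return TOY_TAXONOMY.get(category, [])


for node in expand_prompt_tree("product", toy_llm, max_depth=3):
    print(node)
# product > electronic products > mobile phones
# product > electronic products > laptops
# product > clothing > shoes
```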
For edge generation, the researchers use Gibbs sampling to form edges over the generated node set.
To reduce the computational burden, the LLM is not asked to evaluate every possible edge directly. Instead, the LLM is first used to compute textual similarity between nodes, and a simple algorithm then determines the relationships between them.
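The "similarity first, simple algorithm second" idea can be sketched as follows. To be explicit about what is swapped in: this is not the team's Gibbs sampler but a simpler similarity-weighted sampler, in which each candidate pair is drawn with probability proportional to exp(similarity / temperature). The 2-d vectors stand in for LLM-derived text representations and are made up, and `sample_edges` and `temperature` are hypothetical names. The LLM is never called per pair; only the cheap cosine similarity is computed over candidate pairs.

```python
import math
import random
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def sample_edges(embeddings: List[Sequence[float]], num_edges: int,
                 temperature: float, rng: random.Random) -> List[Tuple[int, int]]:
    """Draw distinct undirected edges, favouring textually similar node pairs."""
    n = len(embeddings)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    # exp(...) is always positive, so every pair can eventually be drawn.
    weights = [math.exp(cosine(embeddings[i], embeddings[j]) / temperature)
               for i, j in pairs]
    edges = set()
    while len(edges) < min(num_edges, len(pairs)):
        edges.add(rng.choices(pairs, weights=weights, k=1)[0])
    return sorted(edges)


# Made-up vectors: 0 = mobile phones, 1 = laptops, 2 = shoes.
emb = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
print(sample_edges(emb, num_edges=1, temperature=0.1, rng=random.Random(0)))
```

A lower `temperature` concentrates edges on the most similar pairs (here phones and laptops); a higher one spreads them more evenly.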
On this basis, the team introduced several technical adjustments:
These steps make the generated graph data not only rich and diverse but also close to real-world connection patterns and structural characteristics.
Notably, the experiments train OpenGraph solely on LLM-generated data and test it on diverse real-world datasets, covering node classification and link prediction tasks.
The experimental design is as follows:
Zero-shot setting.
To evaluate OpenGraph on unseen data, the model is trained on the generated training set and then evaluated on entirely different real-world test sets, with no overlap in nodes, edges, or features between training and test data.
Few-shot setting.
Because many existing methods cannot perform zero-shot prediction effectively, a few-shot setting is also used: baseline models are pre-trained on pre-training data and then fine-tuned with k-shot samples.
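Building a k-shot fine-tuning set amounts to drawing up to k labelled examples per class. A minimal sketch, assuming node-classification labels stored in a dict; `k_shot_sample` is a hypothetical helper, and the paper's exact sampling protocol may differ.

```python
import random
from collections import defaultdict
from typing import Dict, List


def k_shot_sample(labels: Dict[int, str], k: int, rng: random.Random) -> Dict[str, List[int]]:
    """Pick up to k node ids per class for fine-tuning (fewer if a class is small)."""
    by_class: Dict[str, List[int]] = defaultdict(list)
    for node, label in sorted(labels.items()):
        by_class[label].append(node)
    return {label: sorted(rng.sample(nodes, min(k, len(nodes))))
            for label, nodes in sorted(by_class.items())}


labels = {0: "A", 1: "A", 2: "A", 3: "B", 4: "B", 5: "C"}
support = k_shot_sample(labels, k=2, rng=random.Random(0))
print({c: len(v) for c, v in support.items()})  # {'A': 2, 'B': 2, 'C': 1}
```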
Results on 2 tasks and 8 test sets show that OpenGraph significantly outperforms existing methods in zero-shot prediction.
Additionally, existing pre-trained models sometimes perform worse than models trained from scratch on cross-dataset tasks.
The team also examined how the design of the graph Tokenizer affects model performance.
First, experiments showed that skipping adjacency matrix smoothing (a smoothing order of 0) significantly reduces performance, confirming that smoothing is necessary.
The researchers then tried several simple alternatives to topology-aware mapping: one-hot encoded IDs shared across datasets, random mapping, and node degree-based representations.
None of these alternatives performed well.
Cross-dataset ID representations performed worst, degree-based representations also performed poorly, and random mapping, though slightly better, still fell well short of the optimized topology-aware mapping.
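For intuition, the three baseline mappings can be contrasted with a short NumPy sketch. The constructions are assumptions for illustration, not the paper's definitions: the shared ID space wraps node indices into `dim` one-hot slots, random mapping draws an independent Gaussian vector per node, and the degree-based variant one-hot encodes node degree clipped to `dim - 1`. All function names are hypothetical. None of them uses connectivity beyond degree, which offers one plausible reason for their weaker results.

```python
import numpy as np


def shared_id_mapping(adj: np.ndarray, dim: int) -> np.ndarray:
    """One-hot of node index modulo dim: node i gets the same token in every graph."""
    n = adj.shape[0]
    return np.eye(dim)[np.arange(n) % dim]


def random_mapping(adj: np.ndarray, dim: int, rng: np.random.Generator) -> np.ndarray:
    """An independent Gaussian vector per node; ignores topology entirely."""
    return rng.normal(size=(adj.shape[0], dim))


def degree_mapping(adj: np.ndarray, dim: int) -> np.ndarray:
    """One-hot of node degree, clipped to dim - 1."""
    deg = adj.sum(axis=1).astype(int)
    return np.eye(dim)[np.minimum(deg, dim - 1)]


# A star graph: node 0 is connected to nodes 1, 2 and 3.
star = np.array([[0, 1, 1, 1],
                 [1, 0, 0, 0],
                 [1, 0, 0, 0],
                 [1, 0, 0, 0]])
print(degree_mapping(star, dim=4).argmax(axis=1))  # [3 1 1 1]
```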
The team also investigated how different pre-training datasets affect OpenGraph's performance, including datasets generated by the LLM-based knowledge distillation method as well as several real datasets.
The pre-training datasets compared include variants of the generated dataset with one generation technique removed, two real datasets unrelated to the test sets (Yelp2018 and Gowalla), and one real dataset similar to the test sets (ML-10M).
The generated dataset performed well on all test sets, and removing any one of the three generation techniques significantly hurt performance, confirming their effectiveness.
Pre-training on real datasets unrelated to the test sets (Yelp2018 and Gowalla) sometimes degraded performance, possibly because of distribution differences between datasets.
Pre-training on ML-10M achieved the best performance on similar test datasets (such as ML-1M and ML-10M), highlighting the importance of similarity between training and test data.
The team then examined the two sampling techniques used in the graph Transformer module: token sequence sampling (Seq) and anchor sampling (Anc).
They conducted detailed ablation experiments on these two sampling methods to evaluate their specific impact on model performance.
The results show that both token sequence sampling and anchor sampling effectively reduce the model's time and space complexity during training and testing, which matters especially for large-scale graph data, where it significantly improves efficiency.
In terms of accuracy, token sequence sampling improves the model's overall performance: by selecting key tokens, it optimizes the graph representation and strengthens the model's ability to handle complex graph structures.
In contrast, experiments on the ddi dataset show that anchor sampling may have a negative impact on model performance. Anchor sampling simplifies the graph structure by selecting specific nodes as anchor points, but this method may ignore some key graph structure information, thus affecting the accuracy of the model.
In summary, although both sampling techniques have their advantages, in practical applications, the appropriate sampling strategy needs to be carefully selected based on specific data sets and task requirements.
This research aims to develop a highly adaptable framework that can accurately identify and parse complex topological patterns of various graph structures.
The researchers' goal is to significantly enhance the model's generalization ability in zero-shot graph learning tasks, including a variety of downstream applications, by fully leveraging the capabilities of the proposed model.
The model combines a scalable graph Transformer architecture with an LLM-enhanced data augmentation mechanism to improve OpenGraph's efficiency and robustness.
Through extensive testing on multiple standard datasets, the team demonstrated the model’s excellent generalization performance.
As this is an initial attempt at building a graph foundation model, the team's future work will focus on increasing the framework's automation, including automatically identifying noisy connections and conducting counterfactual learning.
The team also plans to learn and extract common, transferable patterns across graph structures to broaden the model's applicability and effectiveness.
Reference links:
[1] Paper: https://arxiv.org/pdf/2403.01121.pdf
[2] Source code: https://github.com/HKUDS/OpenGraph