Avec l'application et la promotion des modèles d'apprentissage en profondeur, les gens ont progressivement découvert que les modèles utilisent souvent de fausses corrélations (Spurious Correlation) dans les données pour obtenir des performances d'entraînement plus élevées. Cependant, comme de telles corrélations ne sont souvent pas vraies sur les données de test, les performances de test de tels modèles sont souvent insatisfaisantes [1]. L'essentiel est que l'objectif traditionnel d'apprentissage automatique (minimisation du risque empirique, ERM) suppose les caractéristiques de distribution indépendantes et identiques de l'ensemble de formation et de test, mais en réalité, les scénarios dans lesquels l'hypothèse de distribution indépendante et identique est vraie sont souvent limités. Dans de nombreux scénarios réels, la distribution des données de formation et la distribution des données de test montrent généralement des incohérences, c'est-à-dire des changements de distribution (Distribution Shifts). Le problème visant à améliorer les performances du modèle dans de tels scénarios est généralement évoqué. problème de généralisation hors distribution (généralisation hors distribution). Une classe de méthodes telles que l'ERM qui se concentrent sur l'apprentissage des corrélations plutôt que sur la causalité dans les données a souvent du mal à gérer les changements de distribution. Bien que de nombreuses méthodes aient émergé ces dernières années et aient fait certains progrès dans le problème de la non-distribution en utilisant le principe d'invariance dans l'inférence causale, la recherche sur les données graphiques est encore limitée. En effet, la généralisation hors distribution des données graphiques est plus difficile que celle des données européennes traditionnelles, ce qui pose davantage de défis à l'apprentissage automatique des graphiques. Cet article prend la tâche de classification des graphes comme exemple pour explorer la généralisation externe de la distribution des graphes basée sur le principe de l'invariance causale.
Ces dernières années, grâce au principe d'invariance causale, les gens ont obtenu un certain succès dans le problème de la généralisation hors distribution des données euclidiennes, mais la recherche sur les données graphiques est encore limitée. Contrairement aux données euclidiennes, la complexité des graphiques pose des défis uniques pour l'utilisation des principes d'invariance causale et pour surmonter les difficultés de généralisation hors distribution.
Pour relever ce défi, nous intégrons l'invariance causale dans l'apprentissage automatique des graphes dans ce travail et proposons un cadre d'apprentissage des graphes invariants d'inspiration causale, qui fournit une nouvelle méthode pour résoudre le problème de la généralisation hors distribution des données graphiques. .théories et méthodes.
L'article a été publié dans NeurIPS 2022. Ce travail a été réalisé en coopération avec l'Université chinoise de Hong Kong, l'Université baptiste de Hong Kong, Tencent AI Lab et l'Université de Sydney. Titre de l'article : Apprentissage des représentations causalement invariantes pour la généralisation hors distribution sur les graphiques
Code du projet : https://github.com/LFhase/CIGA
Premièrement, des changements de distribution des données du graphique peuvent apparaître dans la distribution des caractéristiques des nœuds du graphique (décalages au niveau des attributs). Par exemple, dans un système de recommandation, les produits impliqués dans les données de formation peuvent provenir de certaines catégories populaires, et les utilisateurs impliqués peuvent également provenir de certaines régions spécifiques. Cependant, pendant la phase de test, le système doit gérer correctement les utilisateurs de tous. catégories, régions et produits [2,3,4]. De plus, des changements de distribution des données du graphique peuvent également apparaître dans la distribution structurelle du graphique (Structure-level Shifts). Dès 2019, les gens ont remarqué qu'il était difficile d'apprendre aux réseaux de neurones graphiques formés sur des graphiques plus petits des poids d'attention (Attention) efficaces à généraliser à des graphiques plus grands [5], ce qui favorise également une série de travaux connexes ont été proposés [6,7]. Dans des scénarios réels, ces deux types de changements de distribution peuvent souvent apparaître en même temps, et ces changements de distribution à différents niveaux peuvent également présenter différents modèles de fausses corrélations avec les étiquettes à prédire. Par exemple, dans les systèmes de recommandation, les produits de catégories spécifiques et les utilisateurs de régions spécifiques présentent souvent des structures topologiques uniques sur le graphique d'interaction produit-utilisateur [4]. Dans la prédiction des attributs des molécules médicamenteuses, les molécules médicamenteuses impliquées dans la formation peuvent être trop petites et les résultats de la prédiction seront également affectés par l'environnement de mesure expérimental [8]. De plus, la généralisation hors distribution dans l'espace euclidien suppose souvent que les données proviennent de plusieurs environnements (Environnements) ou domaines (Domaine), et suppose en outre que pendant la formation, le modèle peut obtenir l'environnement dans lequel chaque échantillon dans les données de formation appartiennent, afin de découvrir l'invariance entre les environnements. Cependant, l’obtention d’étiquettes environnementales pour les données nécessite souvent certaines connaissances spécialisées liées aux données, et en raison de la nature abstraite des données graphiques, l’obtention d’étiquettes environnementales pour les données graphiques est plus coûteuse. Par conséquent, la plupart des ensembles de données graphiques existants tels que OGB ne contiennent pas de telles informations sur les étiquettes environnementales. Même si quelques ensembles de données, comme DrugOOD, ont des étiquettes environnementales, il existe différents degrés de bruit. Les méthodes existantes peuvent-elles résoudre le problème de la généralisation hors distribution sur les graphes ? Afin d'avoir une compréhension intuitive des défis de la généralisation hors distribution des données graphiques, nous construisons de nouvelles données basées sur l'ensemble de données Spurious-Motif [9] pour instancier davantage les défis ci-dessus, et essayons utiliser des méthodes existantes telles que la cible de formation IRM [10] pour la généralisation hors distribution sur des données européennes, ou GNN [11] avec des capacités d'expression plus fortes, analyser si les méthodes existantes peuvent résoudre le problème de la généralisation hors distribution des données européennes. données graphiques. Figure 2. Exemple d'ensemble de données Spurious Motif. La tâche Spurious Motif est illustrée à la figure 2. Elle juge principalement l'étiquette du graphique en fonction du fait que le graphique d'entrée contient un sous-graphe avec une structure spécifique (telle que Maison ou Cycle), où la couleur du nœud représente le attribut du nœud. L'utilisation de cet ensemble de données peut clairement tester l'impact des changements de distribution à différents niveaux sur les performances des réseaux de neurones graphiques. Pour un modèle GNN ordinaire formé à l'aide d'ERM : De plus, le modèle ne peut obtenir aucune information relative aux labels environnementaux pendant la formation, et les résultats expérimentaux sont présentés dans la figure 3 (plus de résultats peuvent être trouvés dans l'annexe D de l'article). Figure 3. Performances des méthodes existantes sous différents changements de distribution de graphiques. Comme le montre la figure 3, le GCN ordinaire, qu'il soit formé à l'aide d'ERM ou d'IRM, ne peut pas gérer le décalage structurel (Struc) du graphique lors de l'ajout du décalage d'attribut de nœud de graphique (Mixte) et du graphique après la distribution de taille ; est décalé (dans la figure 3), les performances du modèle seront encore réduites ; de plus, même en utilisant kGNN avec une puissance d'expression plus forte, il est difficile d'éviter de graves pertes de performances (performances moyennes réduites ou variance plus grande). De là, nous arrivons naturellement à la question à étudier : Comment obtenir un modèle GNN capable de faire face à divers changements de distribution de graphes ? Afin de résoudre les problèmes ci-dessus, nous devons définir l'objectif d'apprentissage, c'est-à-dire le réseau neuronal du graphe invariant (Invariant GNN), c'est-à-dire qu'il peut toujours fonctionner dans le pire environnement Bon modèle (voir l'article pour une définition rigoureuse) : Définition 1 (réseau neuronal à graphes invariants) Étant donné une série d'ensembles de données de classification de graphes collectés à partir de différents environnements causalement liés , où contient des échantillons indépendants et distribués de manière identique qui sont considérés comme provenant de l'environnement e. Considérons un réseau neuronal graphique , où et sont respectivement l'espace graphique et l'espace échantillon. en entrée, f est un réseau neuronal graphique invariant, si et seulement si , c'est-à-dire minimiser le pire risque empirique dans tous les environnements, où est la perte empirique du modèle dans l'environnement. Le modèle ne peut obtenir qu'une partie des données dans l'environnement de formation pendant la formation Si aucune hypothèse n'est faite sur le processus de données, l'optimalité minmax requise par la définition du réseau neuronal du graphe invariant est difficile à atteindre. réaliser. Par conséquent, nous utilisons un modèle causal structurel pour modéliser le processus de génération de graphiques du point de vue de l'inférence causale et caractériser la corrélation entre les environnements dans le but de définir l'invariance causale sur les données graphiques. Figure 4. Modèle causal du processus de génération de données graphiques. Sans perte de généralité, nous incorporons toutes les variables latentes qui affectent la génération de graphiques dans l'espace latent et modélisons le processus de génération de graphiques comme . De plus, pour la variable latente , selon qu'elle est affectée ou non par l'environnement E, on la divise en une variable latente invariante (variable latente invariante) et une variable latente parasite (variable latente parasite) . De manière correspondante, les variables latentes C et S affecteront respectivement la génération d'un certain sous-graphe de G, qui sont respectivement enregistrés comme le sous-graphe invariant et le faux sous-graphe , comme le montre la figure 4 (a), et C contrôle principalement l'étiquette Y du graphique. Cela peut également être dérivé davantage , c'est-à-dire que C et Y ont des informations mutuelles plus élevées que S. Ce processus de génération correspond à de nombreux exemples pratiques. Par exemple, les propriétés médicinales d'une molécule sont généralement déterminées par un certain groupe clé (sous-graphe moléculaire) (comme la solubilité dans l'eau de l'hydroxyle-HO par rapport à la molécule). De plus, C a de nombreux types d'interactions avec Y, S et E dans l'espace latent. Il s'ensuit principalement si la fausse variable latente S et l'étiquette Y ont des associations supplémentaires en plus de la variable latente constante C, c'est-à-dire , Il peut être résumé en deux types : FIIF (Fonction invariante entièrement informative) comme le montre la figure 4 (b) et PIIF (Fonction invariante partiellement informative) comme le montre la figure 4 (c). Parmi eux, FIIF signifie que l'étiquette est indépendante du montant de fausse corrélation étant donné les informations invariantes. PIIF est le contraire. Il convient de noter que afin de couvrir autant de changements de distribution de graphiques que possible, notre modèle causal s'efforce de modéliser largement divers modèles de génération de graphiques. Étant donné plus de connaissances sur le processus de génération de graphiques, le modèle causal présenté dans la figure 4 peut être généralisé à des exemples plus spécifiques. Comme dans l'Annexe C.1, nous montrons comment les graphes causals peuvent être généralisés aux travaux antérieurs de Bevilacqua et al. [7] sur l'analyse des changements de distribution de la taille des graphes en ajoutant l'hypothèse d'une limite de graphe supplémentaire (graphon). Sur la base de l'analyse causale ci-dessus, nous pouvons savoir que lorsque le modèle utilise uniquement des sous-graphes invariants pour la prédiction, c'est-à-dire qu'il utilise uniquement la corrélation entre , la prédiction du modèle ne sera pas affectée par les changements dans le environnement E Impact ; Au contraire, si la prédiction du modèle repose sur des informations liées à S ou , ses résultats de prédiction seront considérablement modifiés en raison du changement de E, entraînant une perte de performances. Par conséquent, notre objectif peut être affiné davantage depuis l'apprentissage d'un réseau neuronal à graphes invariants pour : a) identifier des sous-graphes invariants potentiels b) prédire Y à l'aide des sous-graphes identifiés ; Afin de mieux correspondre au processus algorithmique de génération de données, nous divisons en outre le réseau neuronal du graphe en un réseau de reconnaissance de sous-graphes (Featurizer GNN) et un réseau de classification (Classifier GNN) , et , où est l'espace sous-graphe de . Ensuite, l'objectif d'apprentissage du modèle peut être exprimé comme indiqué dans la formule (1) : Parmi eux, est la prédiction du sous-graphe invariant par le réseau de reconnaissance de sous-graphes est # ; 🎜🎜# les informations mutuelles avec Y, généralement, en maximisant peuvent être utilisées en minimisant Prédire le réalisation empirique des pertes de Y. Cependant, en raison du manque de E, il nous est difficile d'utiliser directement E pour vérifier l'indépendance de . À cette fin, nous devons rechercher d'autres conditions d'équivalence pour identifier l'invariant requis. sous-graphiques. Apprentissage de graphes invariants d'inspiration causale . Dans de telles conditions, pensez à maximiser , bien que ait le même effet que Le taille du sous-graphe invariant estimé contient de faux sous-graphes qui ont des informations mutuelles avec Y. Afin d'"éliminer" les éventuels faux sous-graphes dans , nous rechercherons davantage à partir du modèle causal Plus sur les attributs uniques de . Notez que, quel que soit le type de fausse corrélation PIIF ou FIIF, pour le sous-graphe qui maximise les informations mutuelles avec l'étiquette Y, nous avons : En combinant les deux propriétés ci-dessus, nous pouvons déduire Comme il nous est difficile de l'observer directement dans la pratique, nous pouvons l'utiliser comme proxy dans la formule (2) . En même temps, lorsque et sont maximisés en même temps, sera automatiquement minimisé, sinon les prédictions du modèle s'effondreront en une solution triviale. À partir de là, nous avons obtenu la condition d'équivalence de sous-graphe invariant dans un cas simple. En combinaison avec la formule (1), nous avons obtenu la première version du cadre d'apprentissage de graphe invariant inspiré de la causalité (apprentissage de graphe invariant inspiré de la causalité), à savoir CIGAv1 : Parmi eux, et , soit et G sont issus de la même catégorie Y. Dans notre article, nous démontrons en outre que CIGAv1 peut identifier avec succès des sous-graphes invariants potentiels dans le modèle causal correspondant à la figure 4 lorsque la taille du graphique est connue. Cependant, comme les hypothèses précédentes sont trop idéales, en pratique, la taille du sous-graphe invariant peut changer et la taille correspondante est souvent inconnue. En supposant qu'il n'y ait pas de taille de sous-graphe, les exigences CIGAv1 peuvent être satisfaites en identifiant simplement le graphe entier comme un sous-graphe invariant. Par conséquent, nous envisageons de rechercher davantage de propriétés sur les sous-graphes invariants pour supprimer cette hypothèse. a remarqué qu'en maximisant , peut apparaître #🎜🎜 ##🎜 🎜# est supprimée Partage de parties de sous-graphe invariant les mêmes informations mutuelles et associées. Alors, pouvons-nous faire le contraire et maximiser pour supprimer d'éventuelles fausses intrigues secondaires de ? La réponse est oui, on peut utiliser la corrélation entre et Y pour la faire concurrencer l'estimation de . Il convient de noter que lors de la maximisation de , vous devez vous assurer que ne dépassera pas #🎜 🎜# , sinon prédira et tombera dans une solution triviale. Combiné avec cette condition supplémentaire, nous pouvons supprimer l'hypothèse sur la taille constante du sous-graphe de la formule (3) et obtenir le CIGAv2 suivant : # Figure 5. Schéma du cadre d'apprentissage des graphes invariants d'inspiration causale. Mise en œuvre de CIGA : En pratique, il est souvent difficile d'estimer l'information mutuelle de deux sous-graphes, et l'apprentissage contrastif supervisé [ 11 ] propose une solution réalisable : où correspond à l'échantillon positif dans la formule (4), et est la représentation graphique correspondant à . Lorsque , la formule (5) fournit un estimateur d'entropie de resubstituation non paramétrique (estimateur d'entropie de resubstituation non paramétrique) basé sur la densité du noyau de von Mises-Fisher pour [13,14]. La mise en œuvre finale de la partie centrale de CIGA est illustrée à la figure 5, c'est-à-dire en rapprochant la représentation graphique de la même catégorie de sous-graphes invariants dans l'espace de représentation latente, et en maximisant en même temps la représentation graphique des différentes catégories de sous-graphes invariants. sous-graphes invariants pour maximiser . De plus, pour une autre contrainte dans la formule (4), nous pouvons l'implémenter grâce à l'idée de perte charnière, c'est-à-dire , qui optimise uniquement les faux sous-graphes dont la perte empirique lors de la prédiction est supérieure au sous-graphe invariant correspondant. Dans les expériences, nous avons utilisé 16 ensembles de données synthétiques ou réels pour valider pleinement CIGA sous différents changements de distribution graphique. Dans l'expérience, nous avons implémenté le prototype de CIGA en utilisant le framework GNN interprétable [9], mais en fait CIGA a plus de moyens de l'implémenter. Pour des ensembles de données spécifiques et des détails expérimentaux, veuillez consulter la section expérimentale de l'article. Performance du changement de distribution de structure et du changement de distribution mixte dans l'ensemble de données synthétiques Nous avons d'abord construit des ensembles de données SPMotif-Struc et SPMotif-Mixed basés sur l'ensemble de données SPMotif [9], où SPMotif-Struc inclut de fausses corrélations entre des sous-graphiques spécifiques et d'autres structures de sous-graphiques dans le graphique, ainsi que des décalages de distribution des tailles de graphique tandis que SPMotif-Mixed ajoute des décalages de distribution au niveau des attributs du nœud du graphique en fonction de SPMotif-Struc. La première colonne du tableau est la base de référence de l'ERM et du GNN interprétable, et la deuxième colonne est l'algorithme de généralisation hors distribution le plus avancé dans l'espace euclidien. Les résultats montrent que le meilleur cadre GNN et l'algorithme de généralisation hors distribution dans l'espace euclidien sont soumis aux changements de distribution sur le graphique, et que lorsque davantage de changements de distribution se produisent, la perte de performance (performance de classification moyenne plus faible ou une plus grande variance) sera encore améliorée. En revanche, CIGA maintient de bonnes performances sous des changements de distribution de différentes forces et dépasse largement les meilleures performances de base. Performance de divers changements de distribution de graphiques sur des ensembles de données réels Nous avons ensuite testé les performances de CIGA sur des ensembles de données réels et des changements de distribution de graphiques qui existent dans diverses données réelles. Il comprend trois données. ensembles de trois divisions d'environnement différentes (analyse de l'environnement expérimental, échafaudage du squelette moléculaire et taille moléculaire) dans DrugOOD pour la prédiction des attributs des molécules médicamenteuses assistée par l'IA dans les produits pharmaceutiques assistés par l'IA, y compris les changements de distribution des graphiques dans divers scénarios d'application réels ; converti sur la base de l'ensemble de données d'image classique ColoredMNIST [10] dans l'espace euclidien comprend principalement le décalage de distribution de type PIIF des attributs de nœud graphique ; le Graph-SST5 et Twitter convertis sur la base de l'ensemble de données de classification des émotions en langage naturel SST5 et Twitter [15], et a également ajouté un changement de distribution du degré du graphique. De plus, nous avons également utilisé 4 ensembles de données de décalage de distribution de taille de graphe moléculaire précédemment étudiés [7], Les résultats des tests sont présentés dans le tableau ci-dessus. On constate que dans les données réelles, en raison de l'augmentation de la difficulté de la tâche, les performances du modèle obtenues en utilisant un GNN mieux architecturé ou hors de. -La formation des cibles d'optimisation de généralisation de distribution dans l'espace euclidien est encore plus faible que le modèle GNN ordinaire formé à l'aide d'ERM. Ce phénomène est également similaire au phénomène observé dans les expériences de généralisation hors distribution sous des tâches plus difficiles dans l'espace euclidien [16], reflétant la difficulté de la généralisation hors distribution sur des données réelles et les lacunes des méthodes existantes. En revanche, CIGA peut améliorer tous les changements de distribution de données et de graphiques réels, et même atteindre le niveau Oracle empiriquement optimal dans certains ensembles de données tels que Twitter et PROTEINS. Des tests préliminaires sur le dernier test de référence de test de généralisation de graphiques hors distribution BON sur l'ensemble de données de classification de graphiques montrent également que CIGA est actuellement le meilleur algorithme de généralisation de graphiques hors distribution capable de faire face à divers changements de distribution de graphiques. En raison de l'utilisation de GNN interprétables comme architecture de mise en œuvre du prototype de CIGA, nous avons également visualisé le DrugOOD identifié par le modèle et avons constaté que CIGA a trouvé des groupes moléculaires relativement cohérents pour la prédiction des attributs moléculaires. Cela peut fournir une meilleure base pour les produits pharmaceutiques ultérieurs assistés par l’IA. Figure 6. Sous-graphe partiellement invariant identifié par CIGA dans DrugOOD. À travers la perspective de l'inférence causale, cet article introduit pour la première fois l'invariance causale au problème de distribution de graphes hors généralisation sous divers changements de distribution de graphes, et propose une nouvelle solution théoriquement garantie le cadre CIGA. Un grand nombre d'expériences ont également pleinement vérifié les excellentes performances de généralisation hors distribution de CIGA. En regardant vers l'avenir, sur la base de CIGA, nous pouvons explorer davantage de meilleurs cadres de mise en œuvre [17], ou introduire de meilleures méthodes d'amélioration des données théoriquement garanties pour CIGA [3,18], et modéliser théoriquement l'association sur le graphique (décalage variable). ) [19] pour améliorer encore la capacité du CIGA à identifier les sous-graphes invariants et promouvoir la mise en œuvre réelle de réseaux neuronaux graphiques dans des scénarios d'application réels tels que les produits pharmaceutiques assistés par l'IA.
Modèle causal pour la généralisation en dehors de la distribution de données graphiques
Expériences et discussions
Résumé et perspectives
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!