Enhancing Out-of-distribution Generalization on Graphs via Causal Attention Learning
Authors: Yongduo Sui, Wenyu Mao, Shuyao Wang, Xiang Wang, Jiancan Wu, Xiangnan He, Tat-Seng Chua
Journal: ACM Transactions on Knowledge Discovery from Data, vol. 18 (impact factor 4.0; JCR Q1, Computer Science, Information Systems)
Published: 2024-02-05
DOI: 10.1145/3644392 (https://doi.org/10.1145/3644392)
Citations: 0
Abstract
In graph classification, attention- and pooling-based graph neural networks (GNNs) are the predominant way to extract salient features from the input graph and support the prediction. They mostly follow the paradigm of “learning to attend”, which maximizes the mutual information between the attended graph and the ground-truth label. However, this paradigm causes GNN classifiers to indiscriminately absorb all statistical correlations between input features and labels in the training data, without distinguishing the causal and noncausal effects of features. Rather than emphasizing causal features, the attended graphs tend to rely on noncausal features as shortcuts to predictions. These shortcut features may easily change outside the training distribution, leading to poor generalization for GNN classifiers. In this paper, we take a causal view of GNN modeling. Under our causal assumption, the shortcut feature serves as a confounder between the causal feature and the prediction. It misleads the classifier into learning spurious correlations that help prediction on in-distribution (ID) test data but cause a significant performance drop on out-of-distribution (OOD) test data. To address this issue, we employ the backdoor adjustment from causal theory, combining each causal feature with various shortcut features to identify causal patterns and mitigate the confounding effect. Specifically, we employ attention modules to estimate the causal and shortcut features of the input graph. Then, a memory bank collects the estimated shortcut features, enhancing the diversity of shortcut features available for combination. Simultaneously, we apply a prototype strategy to improve the consistency of intra-class causal features. We term our method CAL+; it promotes stable relationships between the causal estimation and the prediction, regardless of distribution changes. Extensive experiments on synthetic and real-world OOD benchmarks demonstrate our method’s effectiveness in improving OOD generalization. Our code is released at https://github.com/shuyao-wang/CAL-plus.
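The combination step described in the abstract can be illustrated with a small sketch. This is not the released CAL+ code: every name here (`split_by_attention`, `ShortcutMemoryBank`, `intervened_features`) is hypothetical, the attention scores are supplied by hand rather than learned, and element-wise addition is only an assumed way of combining a causal feature with a shortcut feature. The prototype strategy is not shown.

```python
import random
from collections import deque


def split_by_attention(node_feats, attn):
    """Pool node features into (causal, shortcut) graph-level vectors.

    attn[i] in [0, 1] is an assumed soft score for node i being causal;
    the complementary weight 1 - attn[i] goes to the shortcut part.
    """
    n, dim = len(node_feats), len(node_feats[0])
    causal = [sum(a * x[d] for a, x in zip(attn, node_feats)) / n
              for d in range(dim)]
    shortcut = [sum((1 - a) * x[d] for a, x in zip(attn, node_feats)) / n
                for d in range(dim)]
    return causal, shortcut


class ShortcutMemoryBank:
    """Fixed-size store of shortcut features seen so far (oldest dropped first)."""

    def __init__(self, capacity):
        self.items = deque(maxlen=capacity)

    def add(self, shortcut):
        self.items.append(list(shortcut))

    def sample(self, k, rng):
        # Sample with replacement so k may exceed the number of stored items.
        return [rng.choice(self.items) for _ in range(k)]


def intervened_features(causal, bank, k, rng):
    """Pair one causal feature with k sampled shortcut features.

    Addition is an illustrative assumption for the combination operator.
    """
    return [[c + s for c, s in zip(causal, sc)] for sc in bank.sample(k, rng)]


# Toy graph: two nodes, two feature dimensions; node 0 is treated as causal.
causal, shortcut = split_by_attention([[1.0, 0.0], [0.0, 2.0]], [1.0, 0.0])
bank = ShortcutMemoryBank(capacity=8)
bank.add(shortcut)
bank.add([3.0, 3.0])  # shortcut feature from some other (hypothetical) graph
combos = intervened_features(causal, bank, k=4, rng=random.Random(0))
```

A classifier trained to give the same label for every entry of `combos` is pushed to rely on the causal part, since the shortcut part varies across the combinations while the label does not.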
About the journal:
TKDD welcomes papers on a full range of research in the knowledge discovery and analysis of diverse forms of data. Such subjects include, but are not limited to: scalable and effective algorithms for data mining and big data analysis, mining brain networks, mining data streams, mining multi-media data, mining high-dimensional data, mining text, Web, and semi-structured data, mining spatial and temporal data, data mining for community generation, social network analysis, and graph structured data, security and privacy issues in data mining, visual, interactive and online data mining, pre-processing and post-processing for data mining, robust and scalable statistical methods, data mining languages, foundations of data mining, KDD framework and process, and novel applications and infrastructures exploiting data mining technology including massively parallel processing and cloud computing platforms. TKDD encourages papers that explore the above subjects in the context of large distributed networks of computers, parallel or multiprocessing computers, or new data devices. TKDD also encourages papers that describe emerging data mining applications that cannot be satisfied by the current data mining technology.