Графовые нейронные сети (GNN)! Представьте их сватами вселенной ИИ, неустанно помогающими точкам данных находить друзей и популярность, исследуя их связи. Лучший ведомый на цифровой вечеринке.
Вы спросите, почему эти GNN так важны? Ну, в реальной жизни вроде бы все взаимосвязано. Мы говорим о таких вещах, как социальные сети, Всемирная паутина, сети частиц и даже изоморфный танец молекул ⚛ (спросите Уолтера Уайта). А вот и ошеломляющее открытие: даже «прямолинейные» структуры данных, такие как тексты, изображения и табличные форматы, можно представить в виде графиков!🧲. Это все равно, что превратить ваши скучные вечеринки с данными, достойными повторения, в моменты озарения! Поверьте, возможности безграничны.
Но подождите, а что за набор этих новых парней из толпы ИИ? Ну, они как классные родственники сверточных и последовательных моделей машинного обучения (ML). Их архитектура вдохновлена словом flexible
, поэтому они скручиваются в воронкообразный пирог (если воронкообразные пироги были структурами данных) только для того, чтобы расшифровывать сложные отношения и решать проблемы, вызывающие зависть даже у Шерлока 🕵.
Содержание
В этой статье мы поговорим об основах графовых структур данных и основанных на графах архитектурах машинного обучения. Подробные объяснения выходят за рамки этой работы, и я предоставил полезные ссылки везде, где это возможно. Кроме того, мы будем создавать некоторые модели с использованием PyTorch Geometric (PyG) (наш плащ Супермена) со следующей дорожной картой:
- Краткий обзор наборов графических данных и представление набора данных Planetoid. Кроме того, мы определим нашу постановку задачи ML прямо здесь.
- Разверните архитектуру GNN и несколько умных формул из воздуха.
- Нет, мы не пропускаем занятия! Поэтому требуется введение в модели PyTorch с адаптированными классами Python.
- Затем мы обучаем модели и тестируем наши творения. Окончательное противостояние, в котором наши GNN будут сражаться с набором данных.
- Подведение итогов и основные выводы.
Пристегнитесь, это будет увлекательное путешествие! 🚀📊
Хорошо, давайте поговорим о наборах графических данных — цифровой игровой площадке, где точки данных тусуются, делятся историями, а иногда даже сплетничают. Представьте их как те взаимосвязанные социальные круги, которые вы найдете на вечеринках, но вместо людей у вас есть узлы, и информация делится между ними через ребра. Теперь узлы и ребра не просто стоят и показывают свои виртуальные большие пальцы вверх 👍 вниз 👎. они - ⭐ шоу, каждый из которых имеет свой набор функций и атрибутов.
Но подождите, мы не собираемся плести все это с нуля. Нет, мы не настолько амбициозны. Давайте поприветствуем пакет Planetoid от PyTorch Geometric и уменьшим шаблон. Это похоже на план построения графика вашей мечты без особых усилий. Набор Лего набор, позволяющий исследователям управлять размером графика, его соединениями и разделением данных.
Cora, классический эталонный набор данных сети цитирования из статьи «Пересмотр полууправляемого обучения с помощью графических вложений». В этом наборе данных каждая исследовательская работа является узлом, а ребра? Ах, они словно невидимые нити, соединяющие статьи через цитаты 📚🤓
Теперь каждый из этих бумажных гостей приходит с подарками — в частности, с пакетом словесных представлений о его содержании. Это праздник словарного запаса, где каждый вектор признаков узла показывает наличие (1) или отсутствие (0) определенного слова из общего числа 1433 вариантов. И позвольте мне сказать вам, что эти газеты — острые едоки; они заботятся только об определенных словах.
В мире науки Cora — лучший выбор для оценки GNN и других методов в таких задачах, как классификация узлов и прогнозирование ссылок. И помните, что на этой вечеринке цитаты (edges
) — лучший ледокол! ➡️
Восторги Коры
x=[2708, 1433]
– это матрица признаков узла. Представьте себе: имеется 2708 документов, и каждый из них представлен 1433-мерным вектором признаков, все они закодированы в горячем режиме.edge_index=[2, 10556]
представляет собой связность графа. Это говорит о том, кто с кем тусуется с формой (2, количество направленных ребер). 📩y=[2708]
– это ярлык достоверности. Каждая нода отнесена к одному классу ровно, чтобы не было неловкого момента — «Ну и что ты исследуешь? 😆train_mask[2708]
,val_mask[2708]
,test_mask[2708]
— это необязательные атрибуты, которые помогают разделить набор данных на наборы для обучения, проверки и тестирования соответственно. Булевы значения, присутствующие в них, утверждают, что правильные узлы смешиваются в правильных местах.
Давайте остановимся на мгновение, чтобы подумать. С вектором признаков из 1433 слов можно легко использовать модель MLP 👷 для какой-то старой доброй классификации узлов/документов. Но эй, мы не из тех, кто соглашается на обычное 🔎. Мы собираемся пересечь границу с
edge_index
до 🤾 и нырнуть с головой в эти отношения, чтобы усилить наши прогнозы. Итак, давайте серьезно подключимся здесь! 🤝
# Let us talk more about edge index/graph connectivity print(f"Shape of graph connectivity: {cora[0].edge_index.shape}") print(cora[0].edge_index)
edge_index
интересен тем, что содержит два списка, в первом из которых шепчутся идентификаторы узлов-источников, а во втором раскрываются сведения об их пунктах назначения. У этой настройки причудливое название: список координат (COO). Это отличный способ эффективно хранить разреженные матрицы, например, когда у вас есть узлы, которые не совсем болтают со всеми в комнате.
Теперь я знаю, о чем вы думаете. Почему бы не использовать простую матрицу смежности? Что ж, в области графических данных не каждый узел является социальной бабочкой. Эти матрицы смежности? Они будут плавать в море нулей, а это не самая эффективная конфигурация памяти. Вот почему главный операционный директор — наш основной подход 🧩, а PyG гарантирует, что края изначально направлены.
# The adjacency matrix can be inferred from the edge_index with a utility function. adj_matrix = torch_geometric.utils.to_dense_adj(cora[0].edge_index)[0].numpy().astype(int) print(f'Shape: {adj_matrix.shape}\nAdjacency matrix: \n{adj_matrix}') # Some more PyG utility functions print(f"Directed: {cora[0].is_directed()}") print(f"Isolated Nodes: {cora[0].has_isolated_nodes()}") print(f"Has Self Loops: {cora[0].has_self_loops()}")
Объект Data
имеет множество замечательных полезных функций, и давайте рассмотрим три примера:
is_directed
сообщает, является ли граф ориентированным, т. е. матрица смежности несимметрична.has_isolated_edges
выявляет узлы-одиночки, оторванные от суетливой толпы. Эти разъединенные души подобны кусочкам головоломки без полной картины, что делает последующие задачи машинного обучения настоящей головной болью.has_self_loops
сообщает, находится ли узел в отношениях сам с собой ❣
Давайте кратко поговорим о визуализации. Преобразование объектов PyG
Data
в объектыNetworkX
графа и их построение — это простая прогулка. Но, придержите лошадей! Наш список гостей (количество узлов) имеет длину более 2 тыс., поэтому попытка визуализировать его будет подобна втискиванию футбольного стадиона в вашу гостиную. Да, ты не хочешь этого ⛔. Так что, пока мы не участвуем в сюжетных вечеринках, просто знайте, что этот граф заряжен и готов к серьезным сетевым действиям, даже если все это происходит за кадром 🌐🕵️♀️
CiteSeer – это научный 🎓 брат Коры из семейства Платеноидов. Он стоит на сцене с 3327 научными статьями, где каждый узел имеет ровно одну из 6 элитных категорий (меток классов). Теперь давайте поговорим о статистике данных, где каждый документ/узел во вселенной CiteSeer определяется 3703-мерным вектором слов со значениями 0/1. Жаждете подробностей? Можете копнуть глубже в кроличью нору🐇
citeseer = load_planetoid(name='CiteSeer')
print(f"Directed: {citeseer[0].is_directed()}") print(f"Isolated Nodes: {citeseer[0].has_isolated_nodes()}") print(f"Has Self Loops: {citeseer[0].has_self_loops()}")
С дуэтом данных сети цитирования, который уже вышел на сцену, мы получили небольшой поворот в научной саге. Набор данных CiteSeer — это не только солнышко; у него есть изолированные узлы (вспомните наших одиночек❓). Теперь задача классификации будет немного сложнее с этими парнями в игре.
Вот в чем загвоздка: эти изолированные узлы бросают вызов магии агрегации GNN (мы вскоре поговорим об этом). Мы ограничены использованием только представления вектора признаков для этих изолированных узлов, что и делают модели многослойного персептрона (MLP).
Отсутствие информации о матрице смежности может привести к снижению точности. Хотя мы мало что можем сделать, чтобы решить эту проблему, мы сделаем все возможное, чтобы пролить свет на эффект отсутствия подключения 📚🔍.
# Node degree distribution node_degrees = torch_geometric.utils.degree(citeseer.edge_index[0]).numpy() node_degrees = Counter(node_degrees) # convertt to a dictionary object # Bar plot fig, ax = plt.subplots(figsize=(18, 6)) ax.set_xlabel('Node degree') ax.set_ylabel('Number of nodes') ax.set_title('CiteSeer - Node Degree Distribution') plt.bar(node_degrees.keys(), node_degrees.values(), color='#0A047A')
У большинства узлов CiteSeer есть 1 или 2 соседа. Теперь вы можете подумать: "Что в этом такого?". Ну, я вам скажу, это как вечеринка с парой друзей — уютно, но без рейва. Глобальной информации об их связях с сообществом будет не хватать. Это может стать еще одной проблемой для GNN по сравнению с Cora.
Определение проблемы
Наша миссия теперь кристально ясна: вооружившись представлением признаков каждого узла и их соединениями с соседними узлами, мы пытаемся предсказать правильную метку класса для каждого узла в данном графе.
Примечание: мы не только полагаемся на матрицу функций узлов поверхностного уровня, но и погружаемся в структуру данных, анализируя каждое взаимодействие и расшифровывая каждый шорох. Это больше связано с пониманием набора данных, чем с простыми необработанными прогнозами на основе закономерностей.
Распутывание графовых нейронных сетей
Мы собираемся демистифицировать магию, стоящую за GNN. Они представляют узлы, ребра или графы в виде числовых векторов, так что каждый узел резонирует со своими исходящими ребрами. Но в чем секрет GNN? Техника, которая привлекает всеобщее внимание: операции передача сообщений, агрегация и обновление применяются рекуррентно. Аналогией для этого может быть проведение вечеринки по соседству, где каждый узел собирает информацию со своими соседями, преобразовывает и обновляет себя, а затем также делится своими обновленными знаниями с остальной толпой. Речь идет об итеративном обновлении их векторов признаков, привнося в них локальную мудрость от их соседей по n-узлам. Познакомьтесь с этой жемчужиной: Введение в GNN, в которой ясно объясняется каждое понятие.
GNN состоят из уровней, каждый из которых расширяет свой переход для доступа к информации от соседей. Например, GNN с двумя слоями узла будет учитывать расстояние friend-of-firend
для сбора информации и обновления своего представления. Просто помните, что вселенная знаний находится на расстоянии одного клика 🖱, и когда вы будете готовы, Интернет готов стать вашим проводником. Объем этой работы состоит не в том, чтобы объяснить их в одном блоге, а в том, чтобы запачкать руки кодированием ⌨ 💻
Базовый ВНС
Мы создаем базовый класс, который закладывает основу для наших реальных моделей GNN. Это набор методов для обучения, оценки и статистики. Здесь не должно повторяться кода!
Мы также настраиваем частные методы для инициализации статистики, связанной с анимацией. Базовый класс позже будет унаследован моделями GCN и GAT, чтобы без суеты использовать общие функции. Легкая эффективность прямо у вас под рукой 🛠️📊🏗️.
# Base GNN Module class BaseGNN(torch.nn.Module): """ Base class for Graph Neural Network models. """ def __init__( self, ): super().__init__() torch.manual_seed(48) # Initialize lists to store animation-related statistics self._init_animate_stats() self.optimizer = None def _init_animate_stats(self) -> None: """Initialize animation-related statistics.""" self.embeddings = [] self.losses = [] self.train_accuracies = [] self.val_accuracies = [] self.predictions = [] def _update_animate_stats( self, embedding: torch.Tensor, loss: torch.Tensor, train_accuracy: float, val_accuracy: float, prediction: torch.Tensor, ) -> None: # Update animation-related statistics with new data self.embeddings.append(embedding) self.losses.append(loss) self.train_accuracies.append(train_accuracy) self.val_accuracies.append(val_accuracy) self.predictions.append(prediction) def accuracy(self, pred_y: torch.Tensor, y: torch.Tensor) -> float: """ Calculate accuracy between predicted and true labels. :param pred (torch.Tensor): Predicted labels. :param y (torch.Tensor): True labels. :returns: Accuracy value. """ return ((pred_y == y).sum() / len(y)).item() def fit(self, data: Data, epochs: int) -> None: """ Train the GNN model on the provided data. :param data: The dataset to use for training. :param epochs: Number of training epochs. """ # Use CrossEntropyLoss as the criterion for training criterion = torch.nn.CrossEntropyLoss() optimizer = self.optimizer self.train() for epoch in range(epochs + 1): # Training optimizer.zero_grad() _, out = self(data.x, data.edge_index) loss = criterion(out[data.train_mask], data.y[data.train_mask]) acc = self.accuracy( out[data.train_mask].argmax(dim=1), data.y[data.train_mask] ) loss.backward() optimizer.step() # Validation val_loss = criterion(out[data.val_mask], data.y[data.val_mask]) val_acc = self.accuracy( out[data.val_mask].argmax(dim=1), data.y[data.val_mask] ) kwargs = { "embedding": out.detach().cpu().numpy(), "loss": loss.detach().cpu().numpy(), "train_accuracy": acc, "val_accuracy": val_acc, "prediction": out.argmax(dim=1).detach().cpu().numpy(), } # Update animation-related statistics self._update_animate_stats(**kwargs) # Print metrics every 10 epochs if epoch % 25 == 0: print( f"Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc: " f"{acc * 100:>6.2f}% | Val Loss: {val_loss:.2f} | " f"Val Acc: {val_acc * 100:.2f}%" ) @torch.no_grad() def test(self, data: Data) -> float: """ Evaluate the model on the test set and return the accuracy score. :param data: The dataset to use for testing. :return: Test accuracy. """ # Set the model to evaluation mode self.eval() _, out = self(data.x, data.edge_index) acc = self.accuracy( out.argmax(dim=1)[data.test_mask], data.y[data.test_mask] ) return acc
Многоуровневая сеть перцептронов
А вот и ванильная многоуровневая сеть перцептронов! Теоретически мы могли бы предсказать категорию документа/узла, просто взглянув на его характеристики. Нет необходимости в реляционной информации — достаточно старого доброго набора слов. Чтобы проверить гипотезу, мы определяем простой двухуровневый MLP, который работает исключительно с функциями входного узла.
Сверточные сети графов
Сверточные нейронные сети (CNN) штурмом взяли сцену машинного обучения благодаря своему изящному трюку с совместным использованием параметров и способности эффективно извлекать скрытые функции. Но разве изображения не являются графиками? Путаница! Давайте будем думать о каждом пикселе как об узле, а о значениях RGB — как об элементах узла. Возникает вопрос: можно ли применить эти трюки CNN в области нерегулярных графов?
Это не так просто, как копипаст. Графики имеют свои особенности:
* **Отсутствие последовательности**. Гибкость — это здорово, но она приносит некоторый хаос. Просто подумайте о молекулах с одинаковой формулой, но разной структурой. Графики могут быть такими хитрыми.
* **Загадка порядка узлов**: Графики не имеют фиксированного порядка, в отличие от текстов или изображений. Узел подобен гостю на вечеринке — без определенного места. Алгоритмы должны быть спокойны 🕳 по поводу отсутствия иерархии узлов.
* **Проблемы масштабирования**: графики могут вырасти БОЛЬШИМИ. Представьте себе социальные сети с миллиардами пользователей и триллионами границ. Работа в таком масштабе — это не прогулка в парке. Разделение и объединение графов — это головоломка, а традиционные купания (операции) не передаются напрямую.
Мы собираем GCN, расширяя наш класс BaseGNN (обычная практика в объектно-ориентированном программировании для обеспечения наследования). Конструктор устанавливает входные, скрытые и выходные размеры, выравнивая шаги нашей сети. Мы дорабатываем оптимизатор для обновлений параметров. Прямые методы берут характеристики узла и связность графа (edge_index), выполняя свертки графа, которые являются танцевальными процедурами для узлов, вдохновленными их соседями. Активация ReLU дает толчок, ведущий к последнему действию: функция log_softmax для вероятностей классов.
class GCN(BaseGNN): """ Graph Convolutional Network model for node classification. """ def __init__( self, input_dim: int, hidden_dim: int, output_dim: int ): super().__init__() self.gcn1 = GCNConv(input_dim, hidden_dim) self.gcn2 = GCNConv(hidden_dim, output_dim) self.optimizer = torch.optim.Adam( self.parameters(), lr=0.01, weight_decay=5e-4 ) def forward( self, x: torch.Tensor, edge_index: torch.Tensor ) -> torch.Tensor: """ Forward pass of the Graph Convolutional Network model. :param (torch.Tensor): Input feature tensor. :param (torch.Tensor): Graph connectivity information :returns torch.Tensor: Output tensor. """ h = F.dropout(x, p=0.5, training=self.training) h = self.gcn1(h, edge_index).relu() h = F.dropout(h, p=0.5, training=self.training) h = self.gcn2(h, edge_index) return h, F.log_softmax(h, dim=1) class GAT(BaseGNN): def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, heads: int=8): super().__init__() torch.manual_seed(48) self.gcn1 = GATConv(input_dim, hidden_dim, heads=heads) self.gcn2 = GATConv(hidden_dim * heads, output_dim, heads=1) self.optimizer = torch.optim.Adam( self.parameters(), lr=0.01, weight_decay=5e-4 ) def forward( self, x: torch.Tensor, edge_index: torch.Tensor ) -> torch.Tensor: """ Forward pass of the Graph Convolutional Network model. :param (torch.Tensor): Input feature tensor. :param (torch.Tensor): Graph connectivity information :returns torch.Tensor: Output tensor. """ h = F.dropout(x, p=0.6, training=self.training) h = self.gcn1(h, edge_index).relu() h = F.dropout(h, p=0.6, training=self.training) h = self.gcn2(h, edge_index).relu() return h, F.log_softmax(h, dim=1)
Обучение модели
Давайте посмотрим, как скрытые представления узлов в графе изменяются с течением времени, поскольку модель проходит обучение задачам классификации узлов.
num_epochs = 200 def train_and_test_model(model, data: Data, num_epochs: int) -> tuple: """ Train and test a given model on the provided data. :param model: The PyTorch model to train and test. :param data: The dataset to use for training and testing. :param num_epochs: Number of training epochs. :return: A tuple containing the trained model and the test accuracy. """ model.fit(data, num_epochs) test_acc = model.test(data) return model, test_acc mlp = MLP( input_dim=cora.num_features, hidden_dim=16, out_dim=cora.num_classes, ) print(f"{mlp}\n", f"-"*88) mlp, test_acc_mlp = train_and_test_model(mlp, data, num_epochs) print(f"-"*88) print(f"\nTest accuracy: {test_acc_mlp * 100:.2f}%\n")
Как видно, наш MLP, кажется, борется с трудностями в центре внимания, точность теста составляет всего около 55%. Но почему MLP не работает лучше? Главный виновник — не что иное, как переобучение — модель слишком привыкла к обучающим данным, что делает ее невежественной при столкновении с новыми представлениями узлов. Это как предсказывать лейблы с одним закрытым глазом. Это также не позволяет включить в модель важную предвзятость. Именно здесь GNN вступают в игру и могут помочь повысить производительность нашей модели.
gcn = GCN( input_dim=cora.num_features, hidden_dim=16, output_dim=cora.num_classes, ) print(f"{gcn}\n", f"-"*88) gcn, test_acc_gcn = train_and_test_model(gcn, data, num_epochs) print(f"-"*88) print(f"\nTest accuracy: {test_acc_gcn * 100:.2f}%\n")
И вот оно: просто заменив эти линейные слои слоями GCN, мы взлетим до ослепительной точности теста 79%!✨ Свидетельство силы реляционной информации между узлами. Как будто мы включили прожектор данных, выявляя скрытые закономерности и связи, которые ранее были утеряны в тени. Цифры не лгут — GNN — это не просто алгоритмы; они шепчут данные.
Точно так же даже GAT работают с более высокой точностью (81 %) из-за их функции многоголового внимания.
gat = GAT( input_dim=cora.num_features, hidden_dim=8, output_dim=cora.num_classes, heads=6, ) print(f"{gat}\n", f"-"*88) gat, test_acc_gat = train_and_test_model(gat, data, num_epochs) print(f"-"*88) print(f"\nTest accuracy: {test_acc_gat * 100:.2f}%\n")
Давайте посмотрим на скрытое представление нашего набора данных CiteSeer с использованием метода уменьшения размерности TSNE. Мы используем `matplotlib` и `seaborn` для построения узлов графика.
import matplotlib.pyplot as plt from sklearn.manifold import TSNE import seaborn as sns # Get embeddings embeddings, _ = gat(citeseer[0].x, citeseer[0].edge_index) # Train TSNE tsne = TSNE(n_components=2, learning_rate='auto', init='pca').fit_transform(embeddings.detach()) # Set the Seaborn theme sns.set_theme(style="whitegrid") # Plot TSNE plt.figure(figsize=(10, 10)) plt.axis('off') sns.scatterplot(x=tsne[:, 0], y=tsne[:, 1], hue=data.y, palette="viridis", s=50) plt.legend([], [], frameon=False) plt.show()
Холст данных рисует показательную картину: узлы одного класса тяготеют друг к другу, образуя кластеры для каждой из шести меток классов. Однако выбросы изолированные узлы играют свою роль в этой драме, поскольку они изменили наши показатели точности.
Помните наше первоначальное предположение о влиянии msing edge? Что ж, гипотеза имеет свое значение. Мы проводим еще один тест, в котором я стремлюсь рассчитать производительность модели GAT, вычислив точность, классифицированную по степени узлов, тем самым выявив важность связности.
Подведение итогов
И на этом мы подошли к последнему разделу, я хотел бы резюмировать основные выводы:
- Мы увидели, почему GNN затмевают MLP, и подчеркнули ключевую роль отношений узлов.
- GAT часто превосходят GCN из-за динамического веса собственного внимания, что приводит к лучшим встраиваниям.
- Будьте осторожны с наложением слоев; слишком много слоев может привести к чрезмерному сглаживанию, когда вложения сходятся и теряют разнообразие.
Мы едва коснулись поверхности. Алгоритмы, с которыми мы познакомились — сверточные сети графов (GCN) или сети внимания графов (GAT) — это только начало. Ребра в графе, вложения узлов и симфония данных ждут дальнейшего изучения. В частности, масштабируемость имеет первостепенное значение, и я хотел бы углубиться в мини-пакетную обработку в своих следующих статьях.
Это знаменует собой мое первое путешествие в блог ⛵ , и это еще не все. Пожалуйста, подумайте о том, чтобы поделиться своими отзывами/комментариями, вашими рекомендациями по областям, которые нужно улучшить, и голосами, если это прочитано, резонировало с вами. До следующей главы пусть алгоритмы будут работать, а ваше любопытство разожжено 💡✌ 🤖.
Стремитесь управлять клетками? Получите доступ к Блокноту Google Colab!