Графовые нейронные сети (GNN)! Представьте их сватами вселенной ИИ, неустанно помогающими точкам данных находить друзей и популярность, исследуя их связи. Лучший ведомый на цифровой вечеринке.

Вы спросите, почему эти GNN так важны? Ну, в реальной жизни вроде бы все взаимосвязано. Мы говорим о таких вещах, как социальные сети, Всемирная паутина, сети частиц и даже изоморфный танец молекул (спросите Уолтера Уайта). А вот и ошеломляющее открытие: даже «прямолинейные» структуры данных, такие как тексты, изображения и табличные форматы, можно представить в виде графиков!🧲. Это все равно, что превратить ваши скучные вечеринки с данными, достойными повторения, в моменты озарения! Поверьте, возможности безграничны.

Но подождите, а что за набор этих новых парней из толпы ИИ? Ну, они как классные родственники сверточных и последовательных моделей машинного обучения (ML). Их архитектура вдохновлена ​​словом flexible , поэтому они скручиваются в воронкообразный пирог (если воронкообразные пироги были структурами данных) только для того, чтобы расшифровывать сложные отношения и решать проблемы, вызывающие зависть даже у Шерлока 🕵.

Содержание

В этой статье мы поговорим об основах графовых структур данных и основанных на графах архитектурах машинного обучения. Подробные объяснения выходят за рамки этой работы, и я предоставил полезные ссылки везде, где это возможно. Кроме того, мы будем создавать некоторые модели с использованием PyTorch Geometric (PyG) (наш плащ Супермена) со следующей дорожной картой:

  1. Краткий обзор наборов графических данных и представление набора данных Planetoid. Кроме того, мы определим нашу постановку задачи ML прямо здесь.
  2. Разверните архитектуру GNN и несколько умных формул из воздуха.
  3. Нет, мы не пропускаем занятия! Поэтому требуется введение в модели PyTorch с адаптированными классами Python.
  4. Затем мы обучаем модели и тестируем наши творения. Окончательное противостояние, в котором наши GNN будут сражаться с набором данных.
  5. Подведение итогов и основные выводы.

Пристегнитесь, это будет увлекательное путешествие! 🚀📊

Хорошо, давайте поговорим о наборах графических данных — цифровой игровой площадке, где точки данных тусуются, делятся историями, а иногда даже сплетничают. Представьте их как те взаимосвязанные социальные круги, которые вы найдете на вечеринках, но вместо людей у ​​вас есть узлы, и информация делится между ними через ребра. Теперь узлы и ребра не просто стоят и показывают свои виртуальные большие пальцы вверх 👍 вниз 👎. они - ⭐ шоу, каждый из которых имеет свой набор функций и атрибутов.

Но подождите, мы не собираемся плести все это с нуля. Нет, мы не настолько амбициозны. Давайте поприветствуем пакет Planetoid от PyTorch Geometric и уменьшим шаблон. Это похоже на план построения графика вашей мечты без особых усилий. Набор Лего набор, позволяющий исследователям управлять размером графика, его соединениями и разделением данных.

Cora, классический эталонный набор данных сети цитирования из статьи «Пересмотр полууправляемого обучения с помощью графических вложений». В этом наборе данных каждая исследовательская работа является узлом, а ребра? Ах, они словно невидимые нити, соединяющие статьи через цитаты 📚🤓

Теперь каждый из этих бумажных гостей приходит с подарками — в частности, с пакетом словесных представлений о его содержании. Это праздник словарного запаса, где каждый вектор признаков узла показывает наличие (1) или отсутствие (0) определенного слова из общего числа 1433 вариантов. И позвольте мне сказать вам, что эти газеты — острые едоки; они заботятся только об определенных словах.

В мире науки Cora — лучший выбор для оценки GNN и других методов в таких задачах, как классификация узлов и прогнозирование ссылок. И помните, что на этой вечеринке цитаты (edges) — лучший ледокол! ➡️

Восторги Коры

  • x=[2708, 1433] – это матрица признаков узла. Представьте себе: имеется 2708 документов, и каждый из них представлен 1433-мерным вектором признаков, все они закодированы в горячем режиме.
  • edge_index=[2, 10556] представляет собой связность графа. Это говорит о том, кто с кем тусуется с формой (2, количество направленных ребер). 📩
  • y=[2708] – это ярлык достоверности. Каждая нода отнесена к одному классу ровно, чтобы не было неловкого момента — «Ну и что ты исследуешь? 😆
  • train_mask[2708], val_mask[2708], test_mask[2708] — это необязательные атрибуты, которые помогают разделить набор данных на наборы для обучения, проверки и тестирования соответственно. Булевы значения, присутствующие в них, утверждают, что правильные узлы смешиваются в правильных местах.

Давайте остановимся на мгновение, чтобы подумать. С вектором признаков из 1433 слов можно легко использовать модель MLP 👷 для какой-то старой доброй классификации узлов/документов. Но эй, мы не из тех, кто соглашается на обычное 🔎. Мы собираемся пересечь границу с edge_index до 🤾 и нырнуть с головой в эти отношения, чтобы усилить наши прогнозы. Итак, давайте серьезно подключимся здесь! 🤝

# Let us talk more about edge index/graph connectivity
print(f"Shape of graph connectivity: {cora[0].edge_index.shape}")
print(cora[0].edge_index)

edge_index интересен тем, что содержит два списка, в первом из которых шепчутся идентификаторы узлов-источников, а во втором раскрываются сведения об их пунктах назначения. У этой настройки причудливое название: список координат (COO). Это отличный способ эффективно хранить разреженные матрицы, например, когда у вас есть узлы, которые не совсем болтают со всеми в комнате.

Теперь я знаю, о чем вы думаете. Почему бы не использовать простую матрицу смежности? Что ж, в области графических данных не каждый узел является социальной бабочкой. Эти матрицы смежности? Они будут плавать в море нулей, а это не самая эффективная конфигурация памяти. Вот почему главный операционный директор — наш основной подход 🧩, а PyG гарантирует, что края изначально направлены.

# The adjacency matrix can be inferred from the edge_index with a utility function.

adj_matrix = torch_geometric.utils.to_dense_adj(cora[0].edge_index)[0].numpy().astype(int)
print(f'Shape: {adj_matrix.shape}\nAdjacency matrix: \n{adj_matrix}')
# Some more PyG utility functions
print(f"Directed: {cora[0].is_directed()}")
print(f"Isolated Nodes: {cora[0].has_isolated_nodes()}")
print(f"Has Self Loops: {cora[0].has_self_loops()}")

Объект Data имеет множество замечательных полезных функций, и давайте рассмотрим три примера:

  • is_directed сообщает, является ли граф ориентированным, т. е. матрица смежности несимметрична.
  • has_isolated_edges выявляет узлы-одиночки, оторванные от суетливой толпы. Эти разъединенные души подобны кусочкам головоломки без полной картины, что делает последующие задачи машинного обучения настоящей головной болью.
  • has_self_loops сообщает, находится ли узел в отношениях сам с собой ❣

Давайте кратко поговорим о визуализации. Преобразование объектов PyG Data в объекты NetworkX графа и их построение — это простая прогулка. Но, придержите лошадей! Наш список гостей (количество узлов) имеет длину более 2 тыс., поэтому попытка визуализировать его будет подобна втискиванию футбольного стадиона в вашу гостиную. Да, ты не хочешь этого ⛔. Так что, пока мы не участвуем в сюжетных вечеринках, просто знайте, что этот граф заряжен и готов к серьезным сетевым действиям, даже если все это происходит за кадром 🌐🕵️‍♀️

CiteSeer – это научный 🎓 брат Коры из семейства Платеноидов. Он стоит на сцене с 3327 научными статьями, где каждый узел имеет ровно одну из 6 элитных категорий (меток классов). Теперь давайте поговорим о статистике данных, где каждый документ/узел во вселенной CiteSeer определяется 3703-мерным вектором слов со значениями 0/1. Жаждете подробностей? Можете копнуть глубже в кроличью нору🐇

citeseer = load_planetoid(name='CiteSeer')

print(f"Directed: {citeseer[0].is_directed()}")
print(f"Isolated Nodes: {citeseer[0].has_isolated_nodes()}")
print(f"Has Self Loops: {citeseer[0].has_self_loops()}")

С дуэтом данных сети цитирования, который уже вышел на сцену, мы получили небольшой поворот в научной саге. Набор данных CiteSeer — это не только солнышко; у него есть изолированные узлы (вспомните наших одиночек❓). Теперь задача классификации будет немного сложнее с этими парнями в игре.

Вот в чем загвоздка: эти изолированные узлы бросают вызов магии агрегации GNN (мы вскоре поговорим об этом). Мы ограничены использованием только представления вектора признаков для этих изолированных узлов, что и делают модели многослойного персептрона (MLP).

Отсутствие информации о матрице смежности может привести к снижению точности. Хотя мы мало что можем сделать, чтобы решить эту проблему, мы сделаем все возможное, чтобы пролить свет на эффект отсутствия подключения 📚🔍.

# Node degree distribution

node_degrees = torch_geometric.utils.degree(citeseer.edge_index[0]).numpy()
node_degrees = Counter(node_degrees)  # convertt to a dictionary object

# Bar plot
fig, ax = plt.subplots(figsize=(18, 6))
ax.set_xlabel('Node degree')
ax.set_ylabel('Number of nodes')
ax.set_title('CiteSeer - Node Degree Distribution')
plt.bar(node_degrees.keys(),
        node_degrees.values(),
        color='#0A047A')

У большинства узлов CiteSeer есть 1 или 2 соседа. Теперь вы можете подумать: "Что в этом такого?". Ну, я вам скажу, это как вечеринка с парой друзей — уютно, но без рейва. Глобальной информации об их связях с сообществом будет не хватать. Это может стать еще одной проблемой для GNN по сравнению с Cora.

Определение проблемы

Наша миссия теперь кристально ясна: вооружившись представлением признаков каждого узла и их соединениями с соседними узлами, мы пытаемся предсказать правильную метку класса для каждого узла в данном графе.

Примечание: мы не только полагаемся на матрицу функций узлов поверхностного уровня, но и погружаемся в структуру данных, анализируя каждое взаимодействие и расшифровывая каждый шорох. Это больше связано с пониманием набора данных, чем с простыми необработанными прогнозами на основе закономерностей.

Распутывание графовых нейронных сетей

Мы собираемся демистифицировать магию, стоящую за GNN. Они представляют узлы, ребра или графы в виде числовых векторов, так что каждый узел резонирует со своими исходящими ребрами. Но в чем секрет GNN? Техника, которая привлекает всеобщее внимание: операции передача сообщений, агрегация и обновление применяются рекуррентно. Аналогией для этого может быть проведение вечеринки по соседству, где каждый узел собирает информацию со своими соседями, преобразовывает и обновляет себя, а затем также делится своими обновленными знаниями с остальной толпой. Речь идет об итеративном обновлении их векторов признаков, привнося в них локальную мудрость от их соседей по n-узлам. Познакомьтесь с этой жемчужиной: Введение в GNN, в которой ясно объясняется каждое понятие.

GNN состоят из уровней, каждый из которых расширяет свой переход для доступа к информации от соседей. Например, GNN с двумя слоями узла будет учитывать расстояние friend-of-firend для сбора информации и обновления своего представления. Просто помните, что вселенная знаний находится на расстоянии одного клика 🖱, и когда вы будете готовы, Интернет готов стать вашим проводником. Объем этой работы состоит не в том, чтобы объяснить их в одном блоге, а в том, чтобы запачкать руки кодированием ⌨ 💻

Базовый ВНС

Мы создаем базовый класс, который закладывает основу для наших реальных моделей GNN. Это набор методов для обучения, оценки и статистики. Здесь не должно повторяться кода!

Мы также настраиваем частные методы для инициализации статистики, связанной с анимацией. Базовый класс позже будет унаследован моделями GCN и GAT, чтобы без суеты использовать общие функции. Легкая эффективность прямо у вас под рукой 🛠️📊🏗️.

# Base GNN Module

class BaseGNN(torch.nn.Module):
    """
    Base class for Graph Neural Network models.
    """

    def __init__(
        self,
    ):
        super().__init__()
        torch.manual_seed(48)
        # Initialize lists to store animation-related statistics
        self._init_animate_stats()
        self.optimizer = None

    def _init_animate_stats(self) -> None:
        """Initialize animation-related statistics."""
        self.embeddings = []
        self.losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.predictions = []

    def _update_animate_stats(
        self,
        embedding: torch.Tensor,
        loss: torch.Tensor,
        train_accuracy: float,
        val_accuracy: float,
        prediction: torch.Tensor,
    ) -> None:
        # Update animation-related statistics with new data
        self.embeddings.append(embedding)
        self.losses.append(loss)
        self.train_accuracies.append(train_accuracy)
        self.val_accuracies.append(val_accuracy)
        self.predictions.append(prediction)

    def accuracy(self, pred_y: torch.Tensor, y: torch.Tensor) -> float:
        """
        Calculate accuracy between predicted and true labels.

        :param pred (torch.Tensor): Predicted labels.
        :param y (torch.Tensor): True labels.

        :returns: Accuracy value.
        """
        return ((pred_y == y).sum() / len(y)).item()

    def fit(self, data: Data, epochs: int) -> None:
        """
        Train the GNN model on the provided data.

        :param data: The dataset to use for training.
        :param epochs: Number of training epochs.
        """
        # Use CrossEntropyLoss as the criterion for training
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = self.optimizer

        self.train()
        for epoch in range(epochs + 1):
            # Training
            optimizer.zero_grad()
            _, out = self(data.x, data.edge_index)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            acc = self.accuracy(
                out[data.train_mask].argmax(dim=1), data.y[data.train_mask]
            )
            loss.backward()
            optimizer.step()

            # Validation
            val_loss = criterion(out[data.val_mask], data.y[data.val_mask])
            val_acc = self.accuracy(
                out[data.val_mask].argmax(dim=1), data.y[data.val_mask]
            )
            kwargs = {
                "embedding": out.detach().cpu().numpy(),
                "loss": loss.detach().cpu().numpy(),
                "train_accuracy": acc,
                "val_accuracy": val_acc,
                "prediction": out.argmax(dim=1).detach().cpu().numpy(),
            }

            # Update animation-related statistics
            self._update_animate_stats(**kwargs)
            # Print metrics every 10 epochs
            if epoch % 25 == 0:
                print(
                    f"Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc: "
                    f"{acc * 100:>6.2f}% | Val Loss: {val_loss:.2f} | "
                    f"Val Acc: {val_acc * 100:.2f}%"
                )

    @torch.no_grad()
    def test(self, data: Data) -> float:
        """
        Evaluate the model on the test set and return the accuracy score.

        :param data: The dataset to use for testing.
        :return: Test accuracy.
        """
        # Set the model to evaluation mode
        self.eval()
        _, out = self(data.x, data.edge_index)
        acc = self.accuracy(
            out.argmax(dim=1)[data.test_mask], data.y[data.test_mask]
        )
        return acc

Многоуровневая сеть перцептронов

А вот и ванильная многоуровневая сеть перцептронов! Теоретически мы могли бы предсказать категорию документа/узла, просто взглянув на его характеристики. Нет необходимости в реляционной информации — достаточно старого доброго набора слов. Чтобы проверить гипотезу, мы определяем простой двухуровневый MLP, который работает исключительно с функциями входного узла.

Сверточные сети графов

Сверточные нейронные сети (CNN) штурмом взяли сцену машинного обучения благодаря своему изящному трюку с совместным использованием параметров и способности эффективно извлекать скрытые функции. Но разве изображения не являются графиками? Путаница! Давайте будем думать о каждом пикселе как об узле, а о значениях RGB — как об элементах узла. Возникает вопрос: можно ли применить эти трюки CNN в области нерегулярных графов?

Это не так просто, как копипаст. Графики имеют свои особенности:

* **Отсутствие последовательности**. Гибкость — это здорово, но она приносит некоторый хаос. Просто подумайте о молекулах с одинаковой формулой, но разной структурой. Графики могут быть такими хитрыми.

* **Загадка порядка узлов**: Графики не имеют фиксированного порядка, в отличие от текстов или изображений. Узел подобен гостю на вечеринке — без определенного места. Алгоритмы должны быть спокойны 🕳 по поводу отсутствия иерархии узлов.

* **Проблемы масштабирования**: графики могут вырасти БОЛЬШИМИ. Представьте себе социальные сети с миллиардами пользователей и триллионами границ. Работа в таком масштабе — это не прогулка в парке. Разделение и объединение графов — это головоломка, а традиционные купания (операции) не передаются напрямую.

Мы собираем GCN, расширяя наш класс BaseGNN (обычная практика в объектно-ориентированном программировании для обеспечения наследования). Конструктор устанавливает входные, скрытые и выходные размеры, выравнивая шаги нашей сети. Мы дорабатываем оптимизатор для обновлений параметров. Прямые методы берут характеристики узла и связность графа (edge_index), выполняя свертки графа, которые являются танцевальными процедурами для узлов, вдохновленными их соседями. Активация ReLU дает толчок, ведущий к последнему действию: функция log_softmax для вероятностей классов.

class GCN(BaseGNN):
    """
    Graph Convolutional Network model for node classification.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, output_dim)
        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=0.01, weight_decay=5e-4
        )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Graph Convolutional Network model.

        :param (torch.Tensor): Input feature tensor.
        :param (torch.Tensor): Graph connectivity information
        :returns torch.Tensor: Output tensor.
        """
        h = F.dropout(x, p=0.5, training=self.training)
        h = self.gcn1(h, edge_index).relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.gcn2(h, edge_index)
 
       return h, F.log_softmax(h, dim=1)
class GAT(BaseGNN):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 heads: int=8):
        super().__init__()
        torch.manual_seed(48)
        self.gcn1 = GATConv(input_dim, hidden_dim, heads=heads)
        self.gcn2 = GATConv(hidden_dim * heads, output_dim, heads=1)
        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=0.01, weight_decay=5e-4
        )

    def forward(
            self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Graph Convolutional Network model.

        :param (torch.Tensor): Input feature tensor.
        :param (torch.Tensor): Graph connectivity information
        :returns torch.Tensor: Output tensor.
        """
        h = F.dropout(x, p=0.6, training=self.training)
        h = self.gcn1(h, edge_index).relu()
        h = F.dropout(h, p=0.6, training=self.training)
        h = self.gcn2(h, edge_index).relu()
        return h, F.log_softmax(h, dim=1)

Обучение модели

Давайте посмотрим, как скрытые представления узлов в графе изменяются с течением времени, поскольку модель проходит обучение задачам классификации узлов.

num_epochs = 200
def train_and_test_model(model, data: Data, num_epochs: int) -> tuple:
    """
    Train and test a given model on the provided data.

    :param model: The PyTorch model to train and test.
    :param data: The dataset to use for training and testing.
    :param num_epochs: Number of training epochs.
    :return: A tuple containing the trained model and the test accuracy.
    """
    model.fit(data, num_epochs)
    test_acc = model.test(data)
    return model, test_acc

mlp = MLP(
    input_dim=cora.num_features,
    hidden_dim=16,
    out_dim=cora.num_classes,
)
print(f"{mlp}\n", f"-"*88)
mlp, test_acc_mlp = train_and_test_model(mlp, data, num_epochs)
print(f"-"*88)
print(f"\nTest accuracy: {test_acc_mlp * 100:.2f}%\n")

Как видно, наш MLP, кажется, борется с трудностями в центре внимания, точность теста составляет всего около 55%. Но почему MLP не работает лучше? Главный виновник — не что иное, как переобучение — модель слишком привыкла к обучающим данным, что делает ее невежественной при столкновении с новыми представлениями узлов. Это как предсказывать лейблы с одним закрытым глазом. Это также не позволяет включить в модель важную предвзятость. Именно здесь GNN вступают в игру и могут помочь повысить производительность нашей модели.

gcn = GCN(
    input_dim=cora.num_features,
    hidden_dim=16,
    output_dim=cora.num_classes,
)
print(f"{gcn}\n", f"-"*88)
gcn, test_acc_gcn = train_and_test_model(gcn, data, num_epochs)
print(f"-"*88)
print(f"\nTest accuracy: {test_acc_gcn * 100:.2f}%\n")

И вот оно: просто заменив эти линейные слои слоями GCN, мы взлетим до ослепительной точности теста 79%!✨ Свидетельство силы реляционной информации между узлами. Как будто мы включили прожектор данных, выявляя скрытые закономерности и связи, которые ранее были утеряны в тени. Цифры не лгут — GNN — это не просто алгоритмы; они шепчут данные.

Точно так же даже GAT работают с более высокой точностью (81 %) из-за их функции многоголового внимания.

gat = GAT(
    input_dim=cora.num_features,
    hidden_dim=8,
    output_dim=cora.num_classes,
    heads=6,
)
print(f"{gat}\n", f"-"*88)
gat, test_acc_gat = train_and_test_model(gat, data, num_epochs)
print(f"-"*88)
print(f"\nTest accuracy: {test_acc_gat * 100:.2f}%\n")

Давайте посмотрим на скрытое представление нашего набора данных CiteSeer с использованием метода уменьшения размерности TSNE. Мы используем `matplotlib` и `seaborn` для построения узлов графика.

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns

# Get embeddings
embeddings, _ = gat(citeseer[0].x, citeseer[0].edge_index)

# Train TSNE
tsne = TSNE(n_components=2, learning_rate='auto',
            init='pca').fit_transform(embeddings.detach())

# Set the Seaborn theme
sns.set_theme(style="whitegrid")

# Plot TSNE
plt.figure(figsize=(10, 10))
plt.axis('off')
sns.scatterplot(x=tsne[:, 0], y=tsne[:, 1], hue=data.y, palette="viridis", s=50)
plt.legend([], [], frameon=False)
plt.show()

Холст данных рисует показательную картину: узлы одного класса тяготеют друг к другу, образуя кластеры для каждой из шести меток классов. Однако выбросы изолированные узлы играют свою роль в этой драме, поскольку они изменили наши показатели точности.

Помните наше первоначальное предположение о влиянии msing edge? Что ж, гипотеза имеет свое значение. Мы проводим еще один тест, в котором я стремлюсь рассчитать производительность модели GAT, вычислив точность, классифицированную по степени узлов, тем самым выявив важность связности.

Подведение итогов

И на этом мы подошли к последнему разделу, я хотел бы резюмировать основные выводы:

  1. Мы увидели, почему GNN затмевают MLP, и подчеркнули ключевую роль отношений узлов.
  2. GAT часто превосходят GCN из-за динамического веса собственного внимания, что приводит к лучшим встраиваниям.
  3. Будьте осторожны с наложением слоев; слишком много слоев может привести к чрезмерному сглаживанию, когда вложения сходятся и теряют разнообразие.

Мы едва коснулись поверхности. Алгоритмы, с которыми мы познакомились — сверточные сети графов (GCN) или сети внимания графов (GAT) — это только начало. Ребра в графе, вложения узлов и симфония данных ждут дальнейшего изучения. В частности, масштабируемость имеет первостепенное значение, и я хотел бы углубиться в мини-пакетную обработку в своих следующих статьях.

Это знаменует собой мое первое путешествие в блог ⛵ , и это еще не все. Пожалуйста, подумайте о том, чтобы поделиться своими отзывами/комментариями, вашими рекомендациями по областям, которые нужно улучшить, и голосами, если это прочитано, резонировало с вами. До следующей главы пусть алгоритмы будут работать, а ваше любопытство разожжено 💡✌ 🤖.

Стремитесь управлять клетками? Получите доступ к Блокноту Google Colab!