Add MoE Gating model base #33

Open
ifrit98 opened this issue Nov 22, 2023 · 0 comments

ifrit98 (Contributor) commented Nov 22, 2023

Create a gating model for a Mixture of Experts (MoE) architecture in PyTorch. We can implement a soft gating mechanism in which the gating weights act as probabilities for selecting different experts, and use the Gumbel-Softmax trick to draw temperature-controlled samples from that categorical distribution while keeping the sampling step differentiable.

This should be part of the validator, as most subnets will want some kind of automatic routing mechanism without having to reinvent the wheel.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingModel(nn.Module):
    def __init__(self, input_dim, num_experts, temperature=1.0):
        super().__init__()

        self.num_experts = num_experts
        self.temperature = temperature

        # Gating network
        self.gating_network = nn.Sequential(
            nn.Linear(input_dim, num_experts),
            nn.Softmax(dim=-1)  # Softmax along the expert dimension
        )

    def forward(self, x):
        # Gating probabilities over the experts: (batch, num_experts)
        gating_probs = self.gating_network(x)

        # Gumbel-Softmax sampling for a differentiable, near-discrete selection.
        # log(gating_probs) serves as the logits; softmax is shift-invariant,
        # so this matches sampling from the underlying categorical distribution.
        gumbel_noise = torch.rand_like(gating_probs)
        gumbel_noise = -torch.log(-torch.log(gumbel_noise + 1e-20) + 1e-20)  # Gumbel(0, 1) noise
        logits = (torch.log(gating_probs + 1e-20) + gumbel_noise) / self.temperature
        selected_experts = F.softmax(logits, dim=-1)

        # Gate-weighted combination. The raw input is broadcast here as a
        # stand-in for per-expert outputs until real expert networks are wired in.
        output = torch.sum(selected_experts.unsqueeze(-1) * x.unsqueeze(-2), dim=-2)

        return output, selected_experts

# Example usage
input_dim = 10
num_experts = 5
temperature = 0.1

# Create a GatingModel
gating_model = GatingModel(input_dim, num_experts, temperature)

# Generate dummy input
input_data = torch.randn(32, input_dim)

# Forward pass through the gating model
output, selected_experts = gating_model(input_data)

# 'output' is the gate-weighted combination, and 'selected_experts' holds the
# soft gating weights per example (approaching one-hot as temperature -> 0).
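
Below is a minimal sketch (not part of the original proposal) of how this gating model could drive actual expert networks; the MixtureOfExperts wrapper, the output_dim parameter, and the per-expert MLPs are illustrative assumptions, not an agreed-on API.

class MixtureOfExperts(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts, temperature=1.0):
        super().__init__()
        # Reuses the GatingModel defined above; hypothetical wiring only.
        self.gating = GatingModel(input_dim, num_experts, temperature)
        # One small MLP per expert; any per-expert module would do here.
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(input_dim, output_dim), nn.ReLU())
             for _ in range(num_experts)]
        )

    def forward(self, x):
        # Soft gating weights over experts: (batch, num_experts)
        _, gate_weights = self.gating(x)
        # Stack per-expert outputs: (batch, num_experts, output_dim)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        # Gate-weighted combination of the expert outputs: (batch, output_dim)
        return torch.sum(gate_weights.unsqueeze(-1) * expert_outputs, dim=-2)

# Example: route a batch of 32 inputs through 5 experts
moe = MixtureOfExperts(input_dim=10, output_dim=4, num_experts=5, temperature=0.5)
moe_output = moe(torch.randn(32, 10))  # shape: (32, 4)

For reference, PyTorch also ships torch.nn.functional.gumbel_softmax(logits, tau=..., hard=...), which implements the same trick on raw logits and, with hard=True, returns one-hot selections with a straight-through gradient.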