Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

item() support for vmap #1108

Open
baskuit opened this issue Jan 28, 2023 · 2 comments
Open

item() support for vmap #1108

baskuit opened this issue Jan 28, 2023 · 2 comments

Comments

@baskuit
Copy link

baskuit commented Jan 28, 2023

Is this being considered? Perhaps I'm not cut out for this, but every single time I've tried to use vmap I've ran into this restriction, resulting in me having to abandon the approach.

@kshitij12345
Copy link
Contributor

Thanks for reporting the issue @baskuit.

It would be great if you can share a minimal repro script so that we can prioritise the required ops. Thanks!

cc: @Chillee

@JustinS6626
Copy link

This is not quite a minimal repro script, but it illustrates the issue:

import time
import os
import copy
from copy import deepcopy
# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, transforms
from torch.func import functional_call
from torch.func import stack_module_state
import copy
from torch import vmap

# Pennylane
import pennylane as qml
from pennylane import numpy as np

torch.manual_seed(42)
#np.random.seed(42)
rng = np.random.default_rng(12345)

# Plotting
import matplotlib.pyplot as plt

n_qubits = 4                # Number of qubits
step = 0.0004               # Learning rate
batch_size = 4              # Number of samples for each training step
num_epochs = 3              # Number of training epochs
q_depth = 6                 # Depth of the quantum circuit (number of variational layers)
gamma_lr_scheduler = 0.1    # Learning rate reduction applied every 10 epochs.
q_delta = 0.01              # Initial spread of random quantum weights
start_time = time.time()    # Start of the computation timer
num_models = 10

#dev = qml.device("default.mixed", wires=n_qubits)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transforms = {
    "train": transforms.Compose(
        [
            # transforms.RandomResizedCrop(224),     # uncomment for data augmentation
            # transforms.RandomHorizontalFlip(),     # uncomment for data augmentation
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Normalize input channels using mean values and standard deviations of ImageNet.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

data_dir = "hymenoptera_data"
image_datasets = {
    x if x == "train" else "validation": datasets.ImageFolder(
        os.path.join(data_dir, x), data_transforms[x]
    )
    for x in ["train", "val"]
}
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "validation"]}
class_names = image_datasets["train"].classes

# Initialize dataloader
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)
    for x in ["train", "validation"]
}

inputs, classes = next(iter(dataloaders["validation"]))

def H_layer(nqubits):
    """Layer of single-qubit Hadamard gates.
    """
    for idx in range(nqubits):
        qml.Hadamard(wires=idx)


def RY_layer(w):
    """Layer of parametrized qubit rotations around the y axis.
    """
    for idx, element in enumerate(w):
        qml.RY(element, wires=idx)


def entangling_layer(nqubits):
    """Layer of CNOTs followed by another shifted layer of CNOT.
    """
    # In other words it should apply something like :
    # CNOT  CNOT  CNOT  CNOT...  CNOT
    #   CNOT  CNOT  CNOT...  CNOT
    for i in range(0, nqubits - 1, 2):  # Loop over even indices: i=0,2,...N-2
        qml.CNOT(wires=[i, i + 1])
    for i in range(1, nqubits - 1, 2):  # Loop over odd indices:  i=1,3,...N-3
        qml.CNOT(wires=[i, i + 1])


#@qml.qnode(dev)
def quantum_net(q_input_features, q_weights_flat, noise_probs):
    """
    The variational quantum circuit.
    """
    #noise_probs = np.random.uniform(size=n_qubits)
    #noise_probs = rng.integers(2, size=n_qubits)
    print("Noise probabilities:")
    print(noise_probs)
    # Reshape weights
    q_weights = q_weights_flat.reshape(q_depth, n_qubits)

    # Start from state |+> , unbiased w.r.t. |0> and |1>
    H_layer(n_qubits)

    # Embed features in the quantum node
    RY_layer(q_input_features)
    for i in range(n_qubits):
        qml.BitFlip(noise_probs[i], wires=i)
        print("bit flip done")
    # Sequence of trainable variational layers
    for k in range(q_depth):
        entangling_layer(n_qubits)
        RY_layer(q_weights[k])

    # Expectation values in the Z basis
    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]
    print("Exp vals calculated")
    return tuple(exp_vals)


class DressedQuantumNet(nn.Module):
    """
    Torch module implementing the *dressed* quantum net.
    """

    def __init__(self):
        """
        Definition of the *dressed* layout.
        """

        super().__init__()
        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        self.premodel = torchvision.models.resnet18(weights=weights)
        self.premodel.fc = nn.Linear(512, n_qubits)
        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))
        self.post_net = nn.Linear(n_qubits, 2)


    def forward(self, input_features, bitflips):
        """
        Defining how tensors are supposed to move through the *dressed* quantum
        net.
        """

        # obtain the input features for the quantum circuit
        # by reducing the feature dimension from 512 to
        dev = qml.device("default.mixed", wires=n_qubits)
        node = qml.QNode(quantum_net, dev, interface="torch", diff_method="backprop")
        #flips = bitflips.clone().detach().cpu().numpy()
        flips = bitflips.item()
        pre_out = self.premodel(input_features)
        q_in = torch.tanh(pre_out) * np.pi / 2.0

        # Apply the quantum circuit to each element of the batch and append to q_out
        q_out = torch.Tensor(0, n_qubits)
        q_out = q_out.to(device)
        for elem in q_in:
            q_out_elem = torch.hstack(node(elem, self.q_params, flips)).float().unsqueeze(0)
            q_out = torch.cat((q_out, q_out_elem))
            print("Q out calculated")

        # return the two-dimensional prediction from the postprocessing layer
        return self.post_net(q_out)


##weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
##model_hybrid = torchvision.models.resnet18(weights=weights)



# Notice that model_hybrid.fc is the last layer of ResNet18
model_hybrid = DressedQuantumNet()

for param in model_hybrid.parameters():
    param.requires_grad = False

# Use CUDA or CPU according to the "device" object.
model_hybrid = model_hybrid.to(device)

criterion = nn.CrossEntropyLoss()
optimizer_hybrid = optim.Adam(model_hybrid.parameters(), lr=step)

exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer_hybrid, step_size=10, gamma=gamma_lr_scheduler
)

def train_model(model, criterion, optimizer, scheduler, num_epochs):
    
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_loss = 10000.0  # Large arbitrary number
    best_acc_train = 0.0
    best_loss_train = 10000.0  # Large arbitrary number
    print("Training started:")

    for epoch in range(num_epochs):
        models = [deepcopy(model) for _ in range(num_models)]
        params, buffers = stack_module_state(models)
        base_model = models[0]
        base_model = base_model.to('meta')
        def fmodel(params, buffers, x, flips):
            return functional_call(base_model, (params, buffers), (x, flips))
        # Each epoch has a training and validation phase
        for phase in ["train", "validation"]:
            if phase == "train":
                # Set model to training mode
                model.train()
            else:
                # Set model to evaluate mode
                model.eval()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            n_batches = dataset_sizes[phase] // batch_size
            it = 0
            for inputs, labels in dataloaders[phase]:
                since_batch = time.time()
                batch_size_ = len(inputs)
                inputs = inputs.to(device)
                labels = labels.to(device)
                all_inputs = torch.stack([inputs.clone() for i in range(num_models)])
                all_labels = torch.stack([labels.clone() for i in range(num_models)])
                bitflips = torch.randint(2, (num_models, n_qubits)).to(device)
                optimizer.zero_grad()

                # Track/compute gradient and make an optimization step only when training
                with torch.set_grad_enabled(phase == "train"):
                    #outputs = model(inputs)
                    outputs = vmap(fmodel)(params, buffers, all_inputs, bitflips)
                    print(outputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, all_labels)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # Print iteration results
                running_loss += loss.item() * batch_size_
                batch_corrects = torch.sum(preds == labels.data).item()
                running_corrects += batch_corrects
                print(
                    "Phase: {} Epoch: {}/{} Iter: {}/{} Batch time: {:.4f}".format(
                        phase,
                        epoch + 1,
                        num_epochs,
                        it + 1,
                        n_batches + 1,
                        time.time() - since_batch,
                    ),
                    end="\r",
                    flush=True,
                )
                it += 1

            # Print epoch results
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(
                "Phase: {} Epoch: {}/{} Loss: {:.4f} Acc: {:.4f}        ".format(
                    "train" if phase == "train" else "validation  ",
                    epoch + 1,
                    num_epochs,
                    epoch_loss,
                    epoch_acc,
                )
            )

            # Check if this is the best model wrt previous epochs
            if phase == "validation" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == "validation" and epoch_loss < best_loss:
                best_loss = epoch_loss
            if phase == "train" and epoch_acc > best_acc_train:
                best_acc_train = epoch_acc
            if phase == "train" and epoch_loss < best_loss_train:
                best_loss_train = epoch_loss

            # Update learning rate
            if phase == "train":
                scheduler.step()

    # Print final results
    model.load_state_dict(best_model_wts)
    time_elapsed = time.time() - since
    print(
        "Training completed in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60)
    )
    print("Best test loss: {:.4f} | Best test accuracy: {:.4f}".format(best_loss, best_acc))
    return model


model_hybrid = train_model(
    model_hybrid, criterion, optimizer_hybrid, exp_lr_scheduler, num_epochs=num_epochs
)

NB you need to have pennylane installed to execute it.

When executed, it outputs this error:

Traceback (most recent call last):
  File "/usr/lib/python3.9/idlelib/run.py", line 559, in runcode
    exec(code, self.locals)
  File "/home/justin/RandomEnsembleToyExample.py", line 312, in <module>
    model_hybrid = train_model(
  File "/home/justin/RandomEnsembleToyExample.py", line 248, in train_model
    outputs = vmap(fmodel)(params, buffers, all_inputs, bitflips)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
    return _flat_vmap(
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/vmap.py", line 44, in fn
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/justin/RandomEnsembleToyExample.py", line 220, in fmodel
    return functional_call(base_model, (params, buffers), (x, flips))
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/utils/stateless.py", line 263, in _functional_call
    return module(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/justin/RandomEnsembleToyExample.py", line 167, in forward
    flips = bitflips.item()
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

It would be great if support for the Tensor.item() function could be added in the next update of vmap.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants