Example

This example will help you get familiar with PyTorchFI. By the end of this tutorial, you should be comfortable doing the following:

  • Initializing PyTorchFI on a pretrained model

  • Injecting an error manually into a neuron during inference

  • Using the provided error models to perform an error injection

  • Understanding some of the design decisions made for PyTorchFI

An interactive tutorial is also available in Google Colab.

Code Setup

Before diving into how to use PyTorchFI, we need to do some preliminary setup. Presumably, you will already have some similar code in your own project.

Here, we will download a pretrained model of AlexNet from the TorchVision repository (you should be able to use any model) and prepare a generic "image" to run through the AlexNet model.

import torch
from torchvision import models

from pytorchfi.core import fault_injection

torch.random.manual_seed(5)

# Model prep
model = models.alexnet(pretrained=True)
model.eval()

# generic image preparation
batch_size = 1
h = 224
w = 224
c = 3

image = torch.rand((batch_size, c, h, w))

Let's see what class we get when we run our generated "image" through AlexNet.

output = model(image)

golden_label = torch.argmax(output, dim=1)[0].item()
print("Error-free label:", golden_label)

OK, so we see that a regular, error-free inference results in class 556. Now for the fun part: let's perturb the network during inference and see what happens! We'll use PyTorchFI to perform an injection on a single neuron.

Initializing PyTorchFI

In order to perform error injections dynamically (i.e., during an inference), you need to provide PyTorchFI with the following:

  • the model being perturbed

  • the batch size the model is using

  • the shape of the input

  • the layer types to perturb

  • whether CUDA is available (optional)

pfi_model = fault_injection(model, 
                            batch_size,
                            input_shape=[c,h,w],
                            layer_types=[torch.nn.Conv2d],
                            use_cuda=False,
                            )

The default input_shape is [3, 224, 224], which is the typical size of an image from the ImageNet dataset. While it is therefore optional in this tutorial, it is important to specify the input shape explicitly if your model's input is not ImageNet-sized.
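For instance, for a hypothetical model trained on CIFAR-10 (32x32 RGB images), the initialization might look like the following sketch, where cifar_model stands in for your own network:

pfi_cifar = fault_injection(cifar_model,
                            batch_size,
                            input_shape=[3, 32, 32],
                            layer_types=[torch.nn.Conv2d],
                            use_cuda=False,
                            )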

The default layer_type is torch.nn.Conv2d. However, you can include any set of layers to perform injections on. Under the hood, PyTorchFI will only append hooks to the layers specified. If you want to include all layers, you can simply use "all" in the list of layer types (see the sketch after this list). For example, here are some ways to initialize PyTorchFI layer_types:

  • layer_types = [torch.nn.Conv2d]

  • layer_types = [torch.nn.Conv2d, torch.nn.Linear]

  • layer_types = [torch.nn.Linear]

  • layer_types = ["all"]

After you run fault_injection(...), PyTorchFI will set up a few internal data structures and profile your network to assist with bounds checking during error injections. To print out what PyTorchFI "sees" after initialization, you can run print_pytorchfi_layer_summary().

print(pfi_model.print_pytorchfi_layer_summary())

Error Injection Overview

In order to perform an injection, we need to specify two things:

  1. Error site: Where to perform the injection.

  2. Error value: What value to change it to.

To address the where, there are two general error injection types: neurons and weights. We'll begin by explaining some of the subtleties of neuron error sites. Weight injections are similar and thus excluded here.

Neuron Injection Sites

Specifying the location of an error injection during a dynamic execution of a DNN requires knowing the shape of the output feature maps. Since the output feature maps (fmaps for short) depend on a variety of things (including the stride, padding, and kernel size), it might not be obvious a priori what bounds you can use to specify the location of an error.

To make your life easier, we profile the network during the init step to obtain the bounds on your network. For example, if we want to perform an injection in Layer 0 (as depicted in the previous section), we know that the shape of the output feature map of the convolution is [1, 64, 55, 55], which corresponds to the [batch, C, H, W] size.
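You can sanity-check that shape with plain PyTorch, outside of PyTorchFI. AlexNet's first convolution has kernel size 11, stride 4, and padding 2, so a 224x224 input yields floor((224 + 2*2 - 11) / 4) + 1 = 55 along each spatial dimension:

# verify the Layer 0 fmap shape directly
first_conv = model.features[0]   # AlexNet's first Conv2d
fmap = first_conv(image)
print(fmap.shape)                # torch.Size([1, 64, 55, 55])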

It is important to note that the Layer number will depend on the layer_types provided. In other words, the Layer # reported by print_pytorchfi_layer_summary() is the number that will be used to ID the layer in PyTorchFI. For example, with layer_types=[torch.nn.Conv2d], AlexNet's five convolutional layers are numbered 0 through 4; including torch.nn.Linear as well would extend the numbering to the fully connected layers.

Neuron Injection Value

The second part of performing an error injection is to define the new value that will be used.

There are two possibilities here:

  1. A new value that is independent of the original value

  2. A new value that is dependent on the original value

The first scenario is straightforward: you can specify any value you want, and tell PyTorchFI to replace the original value with the erroneous value.

The second scenario requires more support: PyTorchFI needs to read the original value, apply some function to it, and then write the new value in place of the old one. For this scenario, the user needs to provide their own custom error function to PyTorchFI (or use one from the error_models package).

Let's start with a simple injection.

# error site (batch, layer, C, H, W) and the value to inject
b, layer, C, H, W, err_val = [0], [2], [4], [2], [4], [10000]

inj = pfi_model.declare_neuron_fi(batch=b, layer_num=layer, dim1=C, dim2=H, dim3=W, value=err_val)

inj_output = inj(image)
inj_label = torch.argmax(inj_output, dim=1)[0].item()
print("[Single Error] PytorchFI label:", inj_label)

In order to declare a neuron (or weight) injection, we call declare_neuron_fi(...) (or declare_weight_fi(...)). We then pass in the location of the error, which is specified by the arguments:

  • batch

  • layer_num

  • dim1, dim2, dim3

and the value, which will either be specified by:

  • value, if it is independent

  • function, if it is dependent

In the example above, we are injecting a value of 10000 in batch element 0, in layer 2 at the tensor location [4, 2, 4] (corresponding to CHW).
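Since each argument is a list, you can specify several error sites in a single call. Below is a hedged sketch, assuming the list-based API above pairs up entries element-wise (the sites and values here are purely illustrative):

# two error sites, injected in the same run
b, layer, C, H, W, err_val = [0, 0], [1, 3], [10, 20], [5, 6], [5, 6], [9999, -9999]

inj = pfi_model.declare_neuron_fi(batch=b, layer_num=layer, dim1=C, dim2=H, dim3=W, value=err_val)
inj_output = inj(image)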

Function-based Neuron Errors

If you want to make your error injection dependent on its former value, you will need to provide a function to do that. Below is an example, where we provide a function called mul_neg_one that multiplies everything by -1:

class custom_func(fault_injection):
    def __init__(self, model, batch_size, **kwargs):
        super().__init__(model, batch_size, **kwargs)

    # define your own function; it must follow PyTorch's
    # forward-hook signature: (module, input, output)
    def mul_neg_one(self, module, input, output):
        # flip the sign of every value in this layer's output fmap
        output[:] = output * -1

        # keep PyTorchFI's layer counter in sync during inference
        self.updateLayer()
        if self.get_current_layer() >= self.get_total_layers():
            self.reset_current_layer()

pfi_model_2 = custom_func(model,
                          batch_size,
                          input_shape=[c,h,w],
                          layer_types=[torch.nn.Conv2d],
                          use_cuda=False,
                          )

inj = pfi_model_2.declare_neuron_fi(function=pfi_model_2.mul_neg_one)

inj_output = inj(image)
inj_label = torch.argmax(inj_output, dim=1)[0].item()
print("[Single Error] PytorchFI label:", inj_label)

The result of this inference is the label 676. Let's explain the additional code in a bit more detail.

First, we extend the fault_injection class to provide our own method. The new method has a very specific signature: it must match the forward-hook signature defined by PyTorch. Specifically, it requires the arguments module, input, and output.

Inside the method, you can manipulate output however you wish. In the example above, we multiplied the entire tensor by -1. For a more nuanced example, take a look at the _set_value(...) method in core. In particular, the first half of that method deals with extracting the location of an error using the same methodology as a value-based error injection. In other words, it allows you to specify the CHW location while performing its own perturbation inside the function.

The last part of the custom method must call updateLayer(). PyTorchFI uses this to keep track internally of which layer the model is in during inference for error injections.
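Putting these pieces together, here is a hedged sketch of another method you could add to custom_func above, one that perturbs a single (C, H, W) location instead of the whole tensor. The location and value are illustrative, and unlike _set_value(...), this simple version fires on every hooked layer:

    def add_large_value(self, module, input, output):
        # perturb one neuron instead of the entire fmap (illustrative location)
        output[0, 4, 2, 4] += 10000.0

        # same bookkeeping as before
        self.updateLayer()
        if self.get_current_layer() >= self.get_total_layers():
            self.reset_current_layer()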

For another example on a custom neuron-based function, check out the single_bit_flip_func in neuron_error_models.py.
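To give a flavor of what a single bit flip does, here is a standalone sketch in plain Python (an illustration of the concept only, not PyTorchFI's actual single_bit_flip_func):

import struct

def flip_bit_float32(value, bit):
    # reinterpret the float's bytes as a uint32, flip one bit, convert back
    as_int = struct.unpack("<I", struct.pack("<f", value))[0]
    as_int ^= 1 << bit
    return struct.unpack("<f", struct.pack("<I", as_int))[0]

print(flip_bit_float32(1.0, 23))  # flips the lowest exponent bit: 1.0 -> 0.5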

For an example of a weight-based function, check out the zero_func_random_weight in weight_error_models.py. Note that for custom weight functions, you do not need to extend core or call updateLayer(); weight injections are much simpler.
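As a rough illustration of why they are simpler (plain PyTorch, not the PyTorchFI API): since weights are static, an injection can be as simple as editing a parameter tensor before running inference. Remember to reload the weights afterwards if you want the golden model back.

import random

# zero out one randomly chosen weight in the first conv layer (illustrative)
with torch.no_grad():
    weight = model.features[0].weight
    idx = random.randrange(weight.numel())
    weight.view(-1)[idx] = 0.0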

And that's it! If you have any more questions, concerns, or feature requests, either add an issue on our GitHub page or reach out to the developers.
