How Can I Calculate FLOPs and Params Without Counting Zero-Weight Neurons?
My pruning code is shown below; after running it, I get a file named 'pruned_model.pth'.
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn
Solution 1:
One thing you could do is exclude weights whose magnitude is below a certain threshold from the FLOPs computation. To do so, you have to modify the FLOP counter hook functions.
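Both modified hooks scale the dense count by the same factor: the fraction of a layer's weights whose magnitude is at least the threshold. Here is a minimal sketch of that factor on its own, assuming the layer was pruned with torch.nn.utils.prune; nonzero_fraction is a hypothetical helper name and the layer size is arbitrary.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def nonzero_fraction(module: nn.Module, threshold: float = 1e-9) -> float:
    """Fraction of module.weight entries with |w| >= threshold (hypothetical helper).

    Same quantity as zero_weights_factor in the hooks:
    1 - (number of near-zero weights) / (total number of weights).
    """
    weight = module.weight.detach()
    num_zero = (weight.abs() < threshold).sum().item()
    return 1.0 - num_zero / weight.numel()


torch.manual_seed(0)
layer = nn.Linear(8, 4)  # 32 weights
# Zero out the 50% of weights with the smallest magnitude (16 of 32).
prune.l1_unstructured(layer, name="weight", amount=0.5)
print(nonzero_fraction(layer))  # 0.5
```

Because torch.nn.utils.prune multiplies the weight by a 0/1 mask, pruned entries are exactly zero, so any small positive threshold identifies them.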
Below are example modifications for the fully connected (fc) and conv layers.
import numpy as np
import torch

def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # pytorch checks dimensions, so here we don't care much
    output_last_dim = output.shape[-1]
    # MODIFICATION HAPPENS HERE
    num_zero_weights = (module.weight.data.abs() < 1e-9).sum()
    zero_weights_factor = 1 - torch.true_divide(num_zero_weights, module.weight.data.numel())
    # .item() (instead of .numpy()) also works when the weights live on a GPU
    module.__flops__ += int(np.prod(input.shape) * output_last_dim * zero_weights_factor.item())
    # MODIFICATION HAPPENS HERE

def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]
    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])
    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(np.prod(kernel_dims)) * in_channels * filters_per_channel
    active_elements_count = batch_size * int(np.prod(output_dims))
    # MODIFICATION HAPPENS HERE
    num_zero_weights = (conv_module.weight.data.abs() < 1e-9).sum()
    zero_weights_factor = 1 - torch.true_divide(num_zero_weights, conv_module.weight.data.numel())
    # Only the multiply-accumulate part is scaled; bias additions are unaffected
    overall_conv_flops = conv_per_position_flops * active_elements_count * zero_weights_factor.item()
    # MODIFICATION HAPPENS HERE
    bias_flops = 0
    if conv_module.bias is not None:
        bias_flops = out_channels * active_elements_count
    overall_flops = overall_conv_flops + bias_flops
    conv_module.__flops__ += int(overall_flops)
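To check the modified formulas end to end without a FLOP-counting library, you can register equivalent hooks yourself. This is a minimal standalone sketch, not the library's code: linear_hook, conv_hook, count_flops and THRESHOLD are hypothetical names, the conv hook is only attached to nn.Conv2d here, and __flops__ is initialized by hand (a counting library normally does that when it registers its hooks). For the 3-to-8-channel 3x3 conv on a 4x4 input with padding 1, the formula gives 3*3 * 3 * 8 = 216 multiply-accumulates per output position times 16 positions = 3,456, plus 8 * 16 = 128 bias additions; pruning half the weights halves only the first term.

```python
import numpy as np
import torch
from torch import nn
import torch.nn.utils.prune as prune

THRESHOLD = 1e-9  # same "counts as zero" threshold as in the answer


def _nonzero_factor(weight: torch.Tensor) -> float:
    # 1 - (fraction of weights with |w| < THRESHOLD)
    return 1.0 - (weight.detach().abs() < THRESHOLD).sum().item() / weight.numel()


def linear_hook(module, inputs, output):
    # Same arithmetic as the modified linear_flops_counter_hook.
    inp = inputs[0]
    module.__flops__ += int(np.prod(inp.shape) * output.shape[-1] * _nonzero_factor(module.weight))


def conv_hook(module, inputs, output):
    # Same arithmetic as the modified conv_flops_counter_hook.
    inp = inputs[0]
    per_position = int(np.prod(module.kernel_size)) * module.in_channels * (module.out_channels // module.groups)
    active = inp.shape[0] * int(np.prod(output.shape[2:]))
    flops = per_position * active * _nonzero_factor(module.weight)
    if module.bias is not None:
        flops += module.out_channels * active
    module.__flops__ += int(flops)


def count_flops(model: nn.Module, x: torch.Tensor) -> int:
    """Run one forward pass with the hooks attached and sum the per-layer counts."""
    hooked, handles = [], []
    for m in model.modules():
        if isinstance(m, nn.Linear):
            hook = linear_hook
        elif isinstance(m, nn.Conv2d):
            hook = conv_hook
        else:
            continue
        m.__flops__ = 0  # a FLOP-counting library would normally initialize this
        handles.append(m.register_forward_hook(hook))
        hooked.append(m)
    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()
    return sum(m.__flops__ for m in hooked)


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # 216 weights
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 10),  # 1280 weights
)
x = torch.randn(1, 3, 4, 4)
print(count_flops(model, x))  # 3456 + 128 + 1280 = 4864

for layer in (model[0], model[3]):
    prune.l1_unstructured(layer, name="weight", amount=0.5)
print(count_flops(model, x))  # 1728 + 128 + 640 = 2496
```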
Note that I'm using 1e-9 as the threshold: any weight whose absolute value is below it counts as zero.
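The hooks above only cover FLOPs, while the question also asks about parameters. A minimal sketch for counting the remaining (non-zero) parameters, assuming the same 1e-9 threshold and a model pruned with torch.nn.utils.prune: until prune.remove is called, a pruned module stores the dense tensor as the parameter weight_orig (plus a weight_mask buffer), so model.parameters() still reports the unpruned tensors. The helper therefore reads the masked attribute weight instead. The name count_params is hypothetical, and stripping a trailing _orig assumes no other parameter uses that suffix.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def count_params(model: nn.Module, threshold: float = 1e-9):
    """Return (total, nonzero) parameter counts for model (hypothetical helper).

    A module pruned with torch.nn.utils.prune (and not yet finalized with
    prune.remove) keeps the dense tensor as the parameter "<name>_orig" and
    exposes the masked tensor as the plain attribute "<name>", so the masked
    attribute is what gets counted here.
    """
    total = 0
    nonzero = 0
    for module in model.modules():
        for pname, _ in module.named_parameters(recurse=False):
            # Assumption: only pruning produces parameter names ending in "_orig".
            name = pname[: -len("_orig")] if pname.endswith("_orig") else pname
            tensor = getattr(module, name).detach()
            total += tensor.numel()
            nonzero += int((tensor.abs() >= threshold).sum())
    return total, nonzero


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))  # 110 + 22 = 132 params
prune.l1_unstructured(model[0], name="weight", amount=0.5)  # zeroes 50 of the 100 weights
print(count_params(model))  # (132, 82)
```

Calling prune.remove(module, "weight") makes the pruning permanent, after which weight is an ordinary parameter containing the zeros; the helper gives the same counts in both cases.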