Source code for ezflow.engine.pruning

from torch.nn.utils import prune


[docs]def prune_l1_unstructured(model, layer_type, proportion): """ L1 unstructured pruning Parameters ---------- model : torch.nn.Module The model to prune layer_type : torch.nn.Module The layer type to prune proportion : float The proportion of weights to prune """ for module in model.modules(): if isinstance(module, layer_type): prune.l1_unstructured(module, "weight", proportion) prune.remove(module, "weight") return model
[docs]def prune_l1_structured(model, layer_type, proportion): """ L1 structured pruning Parameters ---------- model : torch.nn.Module The model to prune layer_type : torch.nn.Module The layer type to prune proportion : float The proportion of weights to prune """ for module in model.modules(): if isinstance(module, layer_type): prune.ln_structured(module, "weight", proportion, n=1, dim=1) prune.remove(module, "weight") return model