Prerequisite :
![[CVPR23] DeepGraph : Towards Any Structural Pruning](https://image.until.blog/namgyu-youn/article/1742485550077.png)

Main Concept : Dependency
Inter-layer Dependency

Inter-layer dependency occurs when one layer's output is physically connected to another layer's input (e.g., given Conv1 → Conv2, the output of Conv1 is the input of Conv2).
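To make the coupling concrete, here is a minimal sketch in plain Python that tracks only weight shapes. It assumes the usual convolution weight layout `(out_channels, in_channels, kH, kW)`; the helpers `drop_out_channels` and `drop_in_channels` are hypothetical names for this example, not part of any library.

```python
def drop_out_channels(shape, idxs):
    """Return the weight shape after removing the output channels in idxs."""
    out_c, in_c, kh, kw = shape
    return (out_c - len(set(idxs)), in_c, kh, kw)


def drop_in_channels(shape, idxs):
    """Return the weight shape after removing the input channels in idxs."""
    out_c, in_c, kh, kw = shape
    return (out_c, in_c - len(set(idxs)), kh, kw)


# Conv1 -> Conv2: Conv1 produces 8 channels, Conv2 consumes those 8 channels.
conv1_w = (8, 3, 3, 3)
conv2_w = (16, 8, 3, 3)

pruned = [0, 1]
conv1_w = drop_out_channels(conv1_w, pruned)  # prune Conv1's outputs ...
conv2_w = drop_in_channels(conv2_w, pruned)   # ... and Conv2's matching inputs

# The two layers still fit together only because both sides were pruned.
assert conv1_w[0] == conv2_w[1]
```

Pruning only one side would leave Conv2 expecting 8 input channels while Conv1 delivers 6.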
Intra-layer Dependency

Intra-layer dependency occurs when the input and output of a single layer must be pruned together.
Note: This is caused by an allocated (private) structure (e.g., a `Mul` operation on tensors allocates (fixes) the shape). Therefore, an allocated structure has to be pruned at the structure level.
c.f., in `Conv`, the input and output channels are independent, so there is no intra-layer dependency.
Configuration
The mapping function (`type2class`) configures the allocated structure in `Node` (see also: Basic Concepts - Node) as follows:
```python
from enum import IntEnum

import torch.nn as nn

# Mapping List
TORCH_CONV = nn.modules.conv._ConvNd
TORCH_BATCHNORM = nn.modules.batchnorm._BatchNorm
...

# Index mapping
class OPTYPE(IntEnum):
    CONV = 0
    BN = 1
    (...)  # omitted members include DEPTHWISE_CONV, which is used below

# Configure module and layer's type
def module2type(module):
    if isinstance(module, TORCH_CONV):
        # groups == out_channels marks a depthwise convolution
        if module.groups == module.out_channels and module.out_channels > 1:
            return OPTYPE.DEPTHWISE_CONV
        else:
            return OPTYPE.CONV
    elif isinstance(module, TORCH_BATCHNORM):
        return OPTYPE.BN

def type2class(op_type):
    if op_type == OPTYPE.CONV or op_type == OPTYPE.DEPTHWISE_CONV:
        return TORCH_CONV
    elif op_type == OPTYPE.BN:
        return TORCH_BATCHNORM
```
After this mapping, the `Node` class calls `module2type` to obtain the operation's type.
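The depthwise branch is the least obvious part of this dispatch, so here is a torch-free sketch of the same logic. `FakeConv` and `FakeBN` are hypothetical stand-ins for the PyTorch base classes, and `OpKind`, `kind_of`, and `class_of` are names invented for this example; the enum values are generated with `auto()` and say nothing about the real `OPTYPE` values.

```python
from dataclasses import dataclass
from enum import Enum, auto


class OpKind(Enum):  # hypothetical stand-in for OPTYPE
    CONV = auto()
    DEPTHWISE_CONV = auto()
    BN = auto()


@dataclass
class FakeConv:  # hypothetical stand-in for a convolution module
    out_channels: int
    groups: int = 1


@dataclass
class FakeBN:  # hypothetical stand-in for a batch-norm module
    num_features: int


def kind_of(module):
    """Mirror the module2type logic on the stand-in classes."""
    if isinstance(module, FakeConv):
        # groups == out_channels (with more than one channel) is treated
        # as a depthwise convolution.
        if module.groups == module.out_channels and module.out_channels > 1:
            return OpKind.DEPTHWISE_CONV
        return OpKind.CONV
    if isinstance(module, FakeBN):
        return OpKind.BN
    raise TypeError(f"unsupported module: {module!r}")


def class_of(kind):
    """Mirror type2class: both convolution kinds map back to the same class."""
    if kind in (OpKind.CONV, OpKind.DEPTHWISE_CONV):
        return FakeConv
    if kind is OpKind.BN:
        return FakeBN
    raise ValueError(kind)


regular = kind_of(FakeConv(out_channels=16))
depthwise = kind_of(FakeConv(out_channels=16, groups=16))
```

Note that the mapping is not one-to-one: two operation types share one class, so the type carries more information than the class alone.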
Case 01 : Batch Normalization
BN shares its `input channel` with its output, since each channel is normalized over the mini-batch and keeps its own parameters. Therefore, pruning a BN layer has to be done at the channel level.
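As a rough illustration, the sketch below stores a BN layer's per-channel vectors as plain Python lists and removes the same indices from all of them. The keys follow the attribute names of PyTorch's `BatchNorm2d` (`weight`, `bias`, `running_mean`, `running_var`), but the dict representation and the `prune_bn_channels` helper are assumptions made for this example.

```python
def prune_bn_channels(bn, idxs):
    """Remove the channels in idxs from every per-channel vector of a BN layer."""
    drop = set(idxs)
    return {
        name: [v for c, v in enumerate(values) if c not in drop]
        for name, values in bn.items()
    }


# A toy BN layer with 4 channels; every vector has one entry per channel.
bn = {
    "weight": [1.0, 1.0, 1.0, 1.0],
    "bias": [0.0, 0.1, 0.2, 0.3],
    "running_mean": [0.5, 0.6, 0.7, 0.8],
    "running_var": [1.0, 1.1, 1.2, 1.3],
}

pruned_bn = prune_bn_channels(bn, idxs=[0, 1])
```

All four vectors shrink together, which is what channel-level pruning means for BN.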
Case 02 : Element-wise Operations
When tensors are combined by an element-wise operation (e.g., Add, Sub, Mul), their shapes are fixed (allocated) and must match.
Suppose a tensor has shape (N, C, H, W): batch N, channels C, height H, width W.
To prune the k-th channel, we must prune (:, k, :, :) in every tensor that shares the intra-layer dependency.
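A small sketch of this rule, using nested lists in place of real tensors so that it runs without any dependency. The helpers `zeros`, `shape`, and `prune_channel` are hypothetical names introduced only for this example.

```python
def zeros(n, c, h, w):
    """Build an (N, C, H, W) nested list filled with zeros."""
    return [[[[0.0] * w for _ in range(h)] for _ in range(c)] for _ in range(n)]


def shape(t):
    """Return the (N, C, H, W) shape of a nested list."""
    return (len(t), len(t[0]), len(t[0][0]), len(t[0][0][0]))


def prune_channel(t, k):
    """Drop the slice (:, k, :, :), i.e. channel k of every sample."""
    return [[ch for c, ch in enumerate(sample) if c != k] for sample in t]


# The two operands of an element-wise Add must agree in shape.
a = zeros(2, 4, 3, 3)
b = zeros(2, 4, 3, 3)

k = 1
a, b = prune_channel(a, k), prune_channel(b, k)  # prune both, not just one

# Element-wise addition still works because the shapes match again.
total = [
    [[[x + y for x, y in zip(ra, rb)] for ra, rb in zip(ca, cb)]
     for ca, cb in zip(sa, sb)]
    for sa, sb in zip(a, b)
]
```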
Main Concept : Grouping
```python
class Group(object):
    def __init__(self):
        self._group = list()  # Pair list : (Dependency, Indices)

    def add_dep(self, dep, idxs):
        """Add a dependency together with its indices."""

    def add_and_merge(self, dep, idxs):
        """Merge the indices into the dependency if add_dep was already called for it."""
```
A dependency enters a `Group` either by being added (`add_dep`) or by being merged into an existing entry (`add_and_merge`).
Example : ResNet (figure 3)
Suppose the architecture is ResNet (Figure 3). The utility function `group.details()` then prints the pruning group as follows:
```
> group.details()
--------------------------------
Pruning Group
--------------------------------
[0] Conv4 → Conv4, idxs (2) =[0, 1] (Pruning start point, no dependency)
[1] Conv4 → BN5, idxs (2) =[0, 1] | Inter
[2] BN5 → BN5, idxs (2) =[0, 1] | Intra : channel-level pruning
[3] BN5 → ReLU6, idxs (2) =[0, 1] | Inter
[4] ReLU6 → ReLU6, idxs (2) =[0, 1] | Intra : channel-level
[5] ReLU6 → Add7, idxs (2) =[0, 1] | Inter
```
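The listing alternates between inter-layer steps (one layer to the next) and intra-layer steps (a layer to itself). The toy traversal below reproduces that ordering for a straight chain and shows one plausible reading of "add and merge". Everything here is an assumption made for illustration: `GroupSketch` is not the library class, dependencies are plain `(source, target, kind)` tuples, and "merge" is taken to mean extending the index list of an entry that already exists.

```python
class GroupSketch:
    """Toy group: an ordered list of (dependency, indices) pairs."""

    def __init__(self):
        self._group = []

    def add_dep(self, dep, idxs):
        self._group.append((dep, list(idxs)))

    def add_and_merge(self, dep, idxs):
        # If the dependency is already present, extend its indices instead
        # of adding a second entry (the assumed meaning of "merge").
        for i, (known, known_idxs) in enumerate(self._group):
            if known == dep:
                extra = [k for k in idxs if k not in known_idxs]
                self._group[i] = (known, known_idxs + extra)
                return
        self.add_dep(dep, idxs)


CHAIN = ["Conv4", "BN5", "ReLU6", "Add7"]  # layer order, as in the listing
INTRA = {"BN5", "ReLU6"}  # layers for which the listing shows an intra step


def trace_group(chain, intra, idxs):
    """Walk the chain and register each dependency as (source, target, kind)."""
    group = GroupSketch()
    group.add_and_merge((chain[0], chain[0], "start"), idxs)
    for src, dst in zip(chain, chain[1:]):
        group.add_and_merge((src, dst, "inter"), idxs)      # layer to layer
        if dst in intra:
            group.add_and_merge((dst, dst, "intra"), idxs)  # layer to itself
    return group


group = trace_group(CHAIN, INTRA, [0, 1])
for i, ((src, dst, kind), idxs) in enumerate(group._group):
    print(f"[{i}] {src} -> {dst}, idxs ({len(idxs)}) = {idxs} | {kind}")

# Registering a known dependency again merges indices rather than duplicating.
group.add_and_merge(("Conv4", "BN5", "inter"), [1, 2])
```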
Basic Concepts (Graph and Dependency)
DependencyGraph
```python
class Dependency(object):
    def __init__(self, trigger, handler, source: Node, target: Node):
        self.trigger, self.handler = trigger, handler  # Trigger & Handler pattern
        self.source, self.target = source, target  # Start & End node
        # index_mapping are used to map the indices of the source node to the target node
        # There will be two index_mapping functions for each dependency, to handle cascaded concat & split operations.
        # E.g. split -> concat
        # We first map the indices to the split tensor with index_mapping[0],
        # then map the split tensor to the concatenated tensor with index_mapping[1].
        # Current coordinate system => Standard coordinate system => target coordinate system
        #                     index_mapping[0]                 index_mapping[1]
        # What does this even mean...?????? (omitted)
        self.index_mapping = [INDEX_MAPPING_PLACEHOLDER, INDEX_MAPPING_PLACEHOLDER]
```
Inter-layer dependency : Dependency caused by the **network** connection (one layer's output feeding another layer's input)
Intra-layer dependency : Dependency caused by a **shared** pruning scheme, even though no connection between layers is involved
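The two-step `index_mapping` comment in the `Dependency` class is easier to follow with numbers. The sketch below is only one interpretation of that comment, not the library's implementation: it assumes a 10-channel tensor that is split into pieces of 4 and 6 channels, where the second piece is later concatenated after a 3-channel tensor. `split_mapping` and `concat_mapping` are hypothetical helpers.

```python
def split_mapping(offset, length):
    """Map indices of the full tensor to local indices of one split piece."""
    def mapping(idxs):
        return [i - offset for i in idxs if offset <= i < offset + length]
    return mapping


def concat_mapping(offset):
    """Map local indices of one input to indices of the concatenated tensor."""
    def mapping(idxs):
        return [i + offset for i in idxs]
    return mapping


# A 10-channel tensor is split into pieces of 4 and 6 channels. The second
# piece (channels 4..9) is then concatenated after a 3-channel tensor.
index_mapping = [
    split_mapping(offset=4, length=6),  # full tensor -> second split piece
    concat_mapping(offset=3),           # second piece -> concatenated tensor
]

idxs = [2, 5, 9]  # channels to prune, in the full tensor's coordinates
for mapping in index_mapping:
    idxs = mapping(idxs)
```

Channel 2 belongs to the first split piece and drops out; channels 5 and 9 become local indices 1 and 5, and then indices 4 and 8 in the concatenated tensor.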
Grouped Parameters
```python
class Group(object):
    def __init__(self):
        self._group = list()  # Pair based list : (Dependency, Index)
        self._DG = None  # DepGraph

    def prune(self, idxs=None, record_history=True):
        """Group-level pruning"""
```
Parameters in a group share a common dependency and should be pruned at the same time.
Node
```python
class Node(object):
    def __init__(self, module: nn.Module, grad_fn, name: str = None):
        self.inputs, self.outputs = [], []  # Network Decomposition (Input/Output)
        self._name = name  # Name of the node
        self.type = ops.module2type(module)  # Type of Node (layer), e.g., Conv, BN
        # For Dependency Modeling
        self.dependencies = []  # List of dependency (object)
```
Dependency
```python
class Dependency(object):
    def __init__(self, trigger, handler, source: Node, target: Node):
        self.trigger, self.handler = trigger, handler  # Trigger & Handler pattern
        self.source, self.target = source, target  # Start & End node
        # index_mapping are used to map the indices of the source node to the target node
        # There will be two index_mapping functions for each dependency, to handle cascaded concat & split operations.
        # E.g. split -> concat
        # We first map the indices to the split tensor with index_mapping[0],
        # then map the split tensor to the concatenated tensor with index_mapping[1].
        # Current coordinate system => Standard coordinate system => target coordinate system
        #                     index_mapping[0]                 index_mapping[1]
        self.index_mapping = [INDEX_MAPPING_PLACEHOLDER, INDEX_MAPPING_PLACEHOLDER]
```
Trigger & Handler pattern
e.g., if a dependency `f- -> f+` exists, then `handler(f+)` is called whenever `trigger(f-)` is applied.
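A minimal sketch of the pattern follows. It assumes that a dependency compares the pruning function that was just applied to the source against its `trigger` and, on a match, runs its `handler` on the target. `DependencySketch`, `is_triggered_by`, `prune_out`, and `prune_in` are names chosen for this example, and layers are reduced to dicts of channel counts.

```python
def prune_out(layer, idxs):
    """Remove output channels from a layer."""
    layer["out"] -= len(idxs)


def prune_in(layer, idxs):
    """Remove input channels from a layer."""
    layer["in"] -= len(idxs)


class DependencySketch:
    def __init__(self, trigger, handler, source, target):
        self.trigger, self.handler = trigger, handler
        self.source, self.target = source, target

    def is_triggered_by(self, pruning_fn):
        """True if pruning_fn applied to the source should fire this dependency."""
        return pruning_fn is self.trigger

    def __call__(self, idxs):
        """Run the handler on the target."""
        self.handler(self.target, idxs)


conv1 = {"in": 3, "out": 8}
conv2 = {"in": 8, "out": 16}
dep = DependencySketch(trigger=prune_out, handler=prune_in,
                       source=conv1, target=conv2)

idxs = [0, 1]
prune_out(conv1, idxs)              # the source (f-) is pruned ...
if dep.is_triggered_by(prune_out):  # ... which matches the trigger ...
    dep(idxs)                       # ... so the handler prunes the target (f+)
```

Keeping the trigger separate from the handler lets the same source layer carry several dependencies, each reacting only to the pruning function it cares about.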