[CodeReview] DepGraph : Towards Any Structural Pruning
Prerequisite :
![[CVPR23] DepGraph : Towards Any Structural Pruning](https://image.until.blog/namgyu-youn/article/1742485550077.png)

Main Concept : Dependency
Inter-layer Dependency

Inter-layer dependency occurs when one layer's output is physically connected to another layer's input (e.g., if there is `Conv1 → Conv2`, the output of `Conv1` is connected to the input of `Conv2`).
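To make the coupling concrete, here is a minimal illustration in plain PyTorch (not the library's code): pruning an output channel of `Conv1` only works if the matching input channel of `Conv2` is pruned as well.

import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 8, 3)   # Conv1 : 8 output channels
conv2 = nn.Conv2d(8, 16, 3)  # Conv2 : 8 input channels, coupled to Conv1's output
x = torch.randn(1, 3, 32, 32)
y = conv2(conv1(x))          # works: channel counts match

# Pruning output channel k of conv1 without pruning input channel k of conv2
# breaks the network: conv2 would still expect 8 input channels.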
Intra-layer Dependency

Note: This is caused by a layer's own allocated (private) structure (e.g., an element-wise `Mul` on a tensor allocates (fixes) its shape). Therefore, when we want to prune an allocated structure, all of its coupled dimensions must be pruned together at the structure level.
c.f., in `Conv`, since the input and output channels are independent, there is no intra-layer dependency.
Configuration
The mapping functions `module2type` and `type2class` configure the allocated structure of each `Node` (see also: Basic Concepts - Node) like the following:
from enum import IntEnum
import torch.nn as nn

# Mapping list
TORCH_CONV = nn.modules.conv._ConvNd
TORCH_BATCHNORM = nn.modules.batchnorm._BatchNorm
...

# Index mapping
class OPTYPE(IntEnum):
    CONV = 0
    BN = 1
    (...)

# Map a module to its layer type
def module2type(module):
    if isinstance(module, TORCH_CONV):
        if module.groups == module.out_channels and module.out_channels > 1:
            return OPTYPE.DEPTHWISE_CONV
        else:
            return OPTYPE.CONV
    elif isinstance(module, TORCH_BATCHNORM):
        return OPTYPE.BN

# Map a layer type back to its module class
def type2class(op_type):
    if op_type == OPTYPE.CONV or op_type == OPTYPE.DEPTHWISE_CONV:
        return TORCH_CONV
    elif op_type == OPTYPE.BN:
        return TORCH_BATCHNORM
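For instance (a small usage sketch; `DEPTHWISE_CONV` lives in the elided part of `OPTYPE`):

conv = nn.Conv2d(8, 8, 3, groups=8)  # depthwise: groups == out_channels > 1
print(module2type(conv))             # OPTYPE.DEPTHWISE_CONV
print(type2class(OPTYPE.BN))         # nn.modules.batchnorm._BatchNorm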
After mapping, the operation's type is referenced in the `Node` class.
Case 01 : Batch Normalization
`BN` shares its input channel because its parameters and statistics are allocated per channel over the mini-batch. Therefore, if you want to prune a BN layer, you should approach it with channel-level pruning.
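A minimal sketch of what channel-level pruning means for BN (`prune_bn_channels` is a hypothetical helper, not the library's API): every per-channel parameter and buffer must be sliced with the same indices.

import torch
import torch.nn as nn

def prune_bn_channels(bn: nn.BatchNorm2d, keep_idxs):
    """Keep only the given channels of a BatchNorm2d layer."""
    idxs = torch.tensor(keep_idxs, dtype=torch.long)
    pruned = nn.BatchNorm2d(len(keep_idxs), affine=bn.affine)
    pruned.running_mean = bn.running_mean[idxs].clone()
    pruned.running_var = bn.running_var[idxs].clone()
    if bn.affine:
        pruned.weight = nn.Parameter(bn.weight[idxs].clone())
        pruned.bias = nn.Parameter(bn.bias[idxs].clone())
    return pruned

bn = nn.BatchNorm2d(4)
print(prune_bn_channels(bn, keep_idxs=[0, 2, 3]).num_features)  # 3 (channel 1 dropped)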
Case 02 : Element-wise Operations
If operations are applied to tensors element-wise (i.e., `Add`, `Sub`, `Mul`), the tensors' shapes are allocated (coupled).
Let's suppose a tensor of shape (N, C, H, W) : batch N, channels C, height H, width W.
If we want to prune the k-th channel, we should prune `(:, k, :, :)` from every tensor that shares the intra-layer dependency.
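A small illustration in plain PyTorch: for an element-wise `Add`, the same channel must be pruned from both operands, or the shapes no longer match.

import torch

a = torch.randn(2, 4, 8, 8)    # (N, C, H, W)
b = torch.randn(2, 4, 8, 8)
keep = [0, 2, 3]               # prune channel k = 1
out = a[:, keep] + b[:, keep]  # OK: both operands pruned, shape (2, 3, 8, 8)
# a[:, keep] + b would raise a shape error: element-wise ops couple the channel dim.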
Main Concept : Grouping
class Group(object):
    def __init__(self):
        self._group = list()  # Pair list : (Dependency, Indices)

    def add_dep(self, dep, idxs):
        """Add a dependency with the given pruning indices"""

    def add_and_merge(self, dep, idxs):
        """Merge the indices if the dependency was already added via add_dep"""
Dependencies are added to a `Group` via `add_dep`, and merged via `add_and_merge` when branches join through operations such as `Add` and `Concat`.
Example : ResNet (Figure 3)
Let's suppose the architecture is ResNet (Figure 3). Then the utility function `group.details()` would visualize the pruning group like the following:
> group.details()
--------------------------------
Pruning Group
--------------------------------
[0] Conv4 → Conv4, idxs (2) =[0, 1] (Pruning start point, no dependency)
[1] Conv4 → BN5, idxs (2) =[0, 1] | Inter
[2] BN5 → BN5, idxs (2) =[0, 1] | Intra : channel-level pruning
[3] BN5 → ReLU6, idxs (2) =[0, 1] | Inter
[4] ReLU6 → ReLU6, idxs (2) =[0, 1] | Intra : channel-level
[5] ReLU6 → Add7, idxs (2) =[0, 1] | Inter
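Such a group can be obtained from the library roughly as follows (a README-style sketch; exact function names may differ across torch-pruning versions):

import torch
import torchvision
import torch_pruning as tp

model = torchvision.models.resnet18()
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[0, 1])
print(group.details())  # prints the dependency chain, as above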

Basic Concepts (Graph and Dependency)
DependencyGraph
`DependencyGraph` records both kinds of dependency:
Inter-layer dependency : Dependency caused by the **network** connection (a layer's input connected to another's output)
Intra-layer dependency : Dependency caused by **shared** pruning schemes within a layer, even though the parameters are not directly connected
Grouped Parameters
class Group(object):
    def __init__(self):
        self._group = list()  # Pair-based list : (Dependency, Indices)
        self._DG = None       # The owning DepGraph

    def prune(self, idxs=None, record_history=True):
        """Group-level pruning"""
Parameters that share a common dependency form one group and must be pruned at the same time.
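Continuing the ResNet sketch above (same `DG` and `group`; `check_pruning_group` per the library's documented workflow), a single call prunes every coupled layer consistently:

if DG.check_pruning_group(group):  # reject groups that would remove all channels
    group.prune()                  # prunes conv1 and every dependent layer together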
Node
class Node(object):
    def __init__(self, module: nn.Module, grad_fn, name: str = None):
        self.inputs, self.outputs = [], []   # Network decomposition (input/output nodes)
        self._name = name                    # Name of the layer
        self.type = ops.module2type(module)  # Type of the node (e.g., CONV, BN)
        # For dependency modeling
        self.dependencies = []               # List of Dependency objects
Dependency
class Dependency(object):
    def __init__(self, trigger, handler, source: Node, target: Node):
        self.trigger, self.handler = trigger, handler  # Trigger & Handler pattern
        self.source, self.target = source, target      # Start & end node
        # index_mapping maps the indices of the source node to the target node.
        # There are two index_mapping functions per dependency, to handle cascaded
        # concat & split operations. E.g., for split -> concat:
        # we first map the indices to the split tensor with index_mapping[0],
        # then map the split tensor to the concatenated tensor with index_mapping[1].
        # Current coordinate system => standard coordinate system => target coordinate system
        #                  index_mapping[0]                  index_mapping[1]
        self.index_mapping = [INDEX_MAPPING_PLACEHOLDER, INDEX_MAPPING_PLACEHOLDER]
Trigger & Handler pattern
e.g., if a dependency exists like `f- -> f+`, then `handler(f+)` is called by `trigger(f-)`.
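A minimal sketch of the pattern (hypothetical pruning functions, not the library's real ones): firing the trigger on the source automatically fires the handler on the target.

def prune_out_channels(layer, idxs):  # trigger on f-
    print(f"prune output channels {idxs} of {layer}")

def prune_in_channels(layer, idxs):   # handler on f+
    print(f"prune input channels {idxs} of {layer}")

class Dep:
    def __init__(self, trigger, handler, source, target):
        self.trigger, self.handler = trigger, handler
        self.source, self.target = source, target

    def __call__(self, idxs):
        # pruning the source (trigger) propagates to the target (handler)
        self.trigger(self.source, idxs)
        self.handler(self.target, idxs)

dep = Dep(prune_out_channels, prune_in_channels, source="Conv1", target="Conv2")
dep([0, 1])  # pruning Conv1's channels 0 and 1 also prunes Conv2's inputs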