[CodeReview] DepGraph : Towards any pruning

Code review of DepGraph: Towards Any Structural Pruning
#TinyML#pruning
2025.05.16 · 5 min read

Prerequisite :

[CVPR23] DepGraph: Towards Any Structural Pruning
DepGraph is a study that aims to overcome the problem in existing structural pruning where group-level pruning does not consider the whole schema.
https://until.blog/@namgyu-youn/-paper--deepgraph---towards-any-structural-pruning-rdre2g47
Gamma: https://gamma.app/docs/DepGraph-Towards-Any-Structural-Pruning-ijgzt91octaix88

Main Concept : Dependency

Inter-layer Dependency


Inter-layer dependency occurs when one layer's output is physically connected to another layer's input (e.g., for Conv1 → Conv2, the output channels of Conv1 feed the input channels of Conv2).
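To make this concrete, here is a minimal hand-written sketch (plain PyTorch, not DepGraph code) of what an inter-layer dependency implies for pruning: removing output channel k of Conv1 forces us to remove input channel k of Conv2 as well.

import torch
import torch.nn as nn

# Hypothetical two-layer stack with an inter-layer dependency: Conv1 -> Conv2
conv1 = nn.Conv2d(3, 8, 3, padding=1)
conv2 = nn.Conv2d(8, 16, 3, padding=1)

keep = [i for i in range(8) if i != 2]  # prune output channel k = 2 of Conv1

# Conv weights have shape (out_channels, in_channels, kH, kW)
conv1.weight = nn.Parameter(conv1.weight.data[keep])   # slice the out-channel dim
conv1.bias = nn.Parameter(conv1.bias.data[keep])
conv1.out_channels = len(keep)

# The physically connected Conv2 must drop the same index on its in-channel dim
conv2.weight = nn.Parameter(conv2.weight.data[:, keep])
conv2.in_channels = len(keep)

x = torch.randn(1, 3, 32, 32)
print(conv2(conv1(x)).shape)  # torch.Size([1, 16, 32, 32])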

Intra-layer Dependency

  • Note: Intra-layer dependency is caused by an allocated (fixed) structure inside a single operation (e.g., an element-wise Mul on tensors fixes their shape, so their channels must match).

  • Therefore, when we want to prune such an allocated structure, the pruning must be applied at the structure (channel) level.

cf. In Conv, since the input and output channels are independent, there is no intra-layer dependency.

Configuration

The mapping functions (module2type, type2class) configure the allocated structure of a Node (see also: Basic Concepts - Node) like the following:

from enum import IntEnum
import torch.nn as nn

# Mapping list : module class per layer type
TORCH_CONV = nn.modules.conv._ConvNd
TORCH_BATCHNORM = nn.modules.batchnorm._BatchNorm
...

# Index mapping
class OPTYPE(IntEnum):
    CONV = 0
    BN = 1
    (...)

# Map a module to its layer type
def module2type(module):
    if isinstance(module, TORCH_CONV):
        if module.groups == module.out_channels and module.out_channels > 1:
            return OPTYPE.DEPTHWISE_CONV
        else:
            return OPTYPE.CONV
    elif isinstance(module, TORCH_BATCHNORM):
        return OPTYPE.BN

# Map a layer type back to its module class
def type2class(op_type):
    if op_type == OPTYPE.CONV or op_type == OPTYPE.DEPTHWISE_CONV:
        return TORCH_CONV
    elif op_type == OPTYPE.BN:
        return TORCH_BATCHNORM
  • After mapping, the operation's type is stored in the Node class (a quick check of the mapping is sketched below).
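Assuming the snippet above is in scope, the mapping behaves like this (a small sanity check, not part of the library):

import torch.nn as nn

print(module2type(nn.Conv2d(16, 32, 3)))             # OPTYPE.CONV
print(module2type(nn.Conv2d(32, 32, 3, groups=32)))  # OPTYPE.DEPTHWISE_CONV (groups == out_channels > 1)
print(module2type(nn.BatchNorm2d(32)))               # OPTYPE.BN
print(type2class(OPTYPE.BN))                         # nn.modules.batchnorm._BatchNorm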

Case 01 : Batch Normalization

  • BN's per-channel parameters and running statistics are allocated along the channel dimension (normalization runs over the mini-batch for each channel), so its input and output channels are coupled. Therefore, if you want to prune BN, you must use channel-level pruning, as sketched below.
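A minimal manual sketch in plain PyTorch (this is what DepGraph automates): once the preceding Conv's output channels are pruned, BN's per-channel parameters and statistics must be sliced with the same indices.

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, padding=1)
bn = nn.BatchNorm2d(8)

keep = [0, 1, 3, 4, 5, 6, 7]  # prune channel k = 2

conv.weight = nn.Parameter(conv.weight.data[keep])
conv.bias = nn.Parameter(conv.bias.data[keep])
conv.out_channels = len(keep)

# BN is pruned at the channel level with the same indices
bn.weight = nn.Parameter(bn.weight.data[keep])
bn.bias = nn.Parameter(bn.bias.data[keep])
bn.running_mean = bn.running_mean[keep]
bn.running_var = bn.running_var[keep]
bn.num_features = len(keep)

x = torch.randn(1, 3, 32, 32)
print(bn(conv(x)).shape)  # torch.Size([1, 7, 32, 32])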

Case 02 : Element-wise Operations

If operations are applied element-wise (e.g., Add, Sub, Mul), the tensors' shapes are allocated (fixed) to match each other.

Let's suppose a tensor of shape (N, C, H, W): batch N, channels C, height H, width W.

If we want to prune the k-th channel, we should prune (:, k, :, :) from every tensor that shares this intra-layer dependency.
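For instance, an element-wise Add (e.g., a residual connection) only works if both operands are pruned with the same channel indices; the illustration below uses plain tensors rather than DepGraph:

import torch

a = torch.randn(1, 8, 32, 32)   # branch 1
b = torch.randn(1, 8, 32, 32)   # branch 2, joined with branch 1 by an element-wise Add

keep = [0, 1, 3, 4, 5, 6, 7]    # prune channel k = 2

# a[:, keep] + b              -> RuntimeError: 7 vs. 8 channels
out = a[:, keep] + b[:, keep]   # both tensors must drop the same channel
print(out.shape)                # torch.Size([1, 7, 32, 32])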


Main Concept : Grouping

class Group(object):
    def __init__(self):
        self._group = list()  # Pair list : (Dependency, Indices)

    def add_dep(self, dep, idxs):
        "Add a dependency with the given pruning indices"

    def add_and_merge(self, dep, idxs):
        "Merge the indices into an existing entry if add_dep was already called for this dependency"

A dependency is added to the Group via add_dep, and its indices are concatenated (merged) into an existing entry via add_and_merge.
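The intended behavior can be illustrated with a simplified stand-in (not the library implementation, which stores Dependency objects): add_dep appends a new (dependency, indices) pair, while add_and_merge merges indices into an existing pair for the same dependency.

# Simplified stand-in for illustration only
class SimpleGroup:
    def __init__(self):
        self._group = []  # list of (dep, idxs) pairs

    def add_dep(self, dep, idxs):
        self._group.append((dep, list(idxs)))

    def add_and_merge(self, dep, idxs):
        for i, (d, old_idxs) in enumerate(self._group):
            if d == dep:  # dependency already recorded -> merge the indices
                self._group[i] = (d, sorted(set(old_idxs) | set(idxs)))
                return
        self.add_dep(dep, idxs)

g = SimpleGroup()
g.add_dep("Conv4 -> BN5", [0, 1])
g.add_and_merge("Conv4 -> BN5", [1, 2])
print(g._group)  # [('Conv4 -> BN5', [0, 1, 2])]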

Example : ResNet (figure 3)

Let's suppose the architecture is ResNet (Figure 3). Then the utility function group.details() would visualize the pruning group like the following:

> group.details()
--------------------------------
          Pruning Group
--------------------------------
[0] Conv4 → Conv4, idxs (2) =[0, 1]  (pruning start point, no dependency)
[1] Conv4 → BN5, idxs (2) =[0, 1]    | Inter
[2] BN5 → BN5, idxs (2) =[0, 1]      | Intra : channel-level pruning
[3] BN5 → ReLU6, idxs (2) =[0, 1]    | Inter
[4] ReLU6 → ReLU6, idxs (2) =[0, 1]  | Intra : channel-level
[5] ReLU6 → Add7, idxs (2) =[0, 1]   | Inter

Basic Concepts (Graph and Dependency)

DependencyGraph

  • Inter-layer dependency : dependency caused by the **network** connection (one layer's output feeding another layer's input)

  • Intra-layer dependency : dependency caused by a **shared** pruning scheme inside a layer, even though the channels are not connected across layers

Grouped Parameters

class Group(object):
    def __init__(self):
        self._group = list()  # Pair based list : (Dependency, Index)
        self._DG = None      # DepGraph
    
    def prune(self, idxs=None, record_history=True):
        """Group-level pruning"""
  • Grouped parameters that share a common dependency should be pruned at the same time, as in the usage sketch below.
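For reference, the typical end-to-end usage (following the Torch-Pruning README; exact names may vary across versions) builds the dependency graph once, asks it for the group coupled with some channels, and prunes the whole group at once:

import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

# 1. Trace the model to record inter-/intra-layer dependencies
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Collect everything coupled with output channels [2, 6, 9] of conv1
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])

# 3. Prune all grouped parameters at the same time
if DG.check_pruning_group(group):  # avoid accidentally removing a whole layer
    group.prune()

print(group.details())  # the table shown in the ResNet example above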

Node

class Node(object):
    def __init__(self, module: nn.Module, grad_fn, name: str = None):
        self.inputs, self.outputs = [], []    # Network decomposition (input/output nodes)
        self._name = name                     # Layer name
        self.type = ops.module2type(module)   # Type of the node's layer (e.g., Conv, BN)

        # For dependency modeling
        self.dependencies = []  # List of Dependency objects

Dependency

class Dependency(object):
    def __init__(self, trigger, handler, source: Node, target: Node):
        self.trigger, self.handler = trigger, handler  # Trigger & Handler pattern
        self.source, self.target = source, target      # Start & end node

        # index_mapping maps the indices of the source node to the target node.
        # Each dependency carries two index_mapping functions to handle cascaded
        # concat & split operations, e.g. split -> concat:
        # first map the indices to the split tensor with index_mapping[0],
        # then map the split tensor to the concatenated tensor with index_mapping[1].
        # Current coordinate system => standard coordinate system => target coordinate system
        #                 index_mapping[0]                index_mapping[1]
        self.index_mapping = [INDEX_MAPPING_PLACEHOLDER, INDEX_MAPPING_PLACEHOLDER]
  • Trigger & Handler pattern

  • e.g., if a dependency f- → f+ exists, then pruning f- (the trigger) causes the handler to be called on f+, as in the toy sketch below.
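A toy sketch of the pattern (illustrative names, not the library code): the trigger is the pruning function applied to the source, and firing it obliges the dependency to call the handler on the target with the (index-mapped) indices.

def prune_out_channels(layer, idxs):  # hypothetical trigger
    print(f"prune out-channels {idxs} of {layer}")

def prune_in_channels(layer, idxs):   # hypothetical handler
    print(f"prune in-channels {idxs} of {layer}")

class ToyDependency:
    def __init__(self, trigger, handler, source, target):
        self.trigger, self.handler = trigger, handler
        self.source, self.target = source, target

    def __call__(self, idxs):
        # firing the trigger on the source propagates pruning to the target
        self.handler(self.target, idxs)

dep = ToyDependency(prune_out_channels, prune_in_channels, source="Conv1", target="Conv2")
dep([0, 1])  # -> prune in-channels [0, 1] of Conv2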






