
    [CodeReview] DepGraph : Towards any pruning

    Code review of DepGraph: Towards Any Structural Pruning
    #TinyML#pruning
    namgyu-youn · 2025.05.16 · 5 min read

    Prerequisite :

    [CVPR23] DepGraph : Towards Any Structural Pruning
    DepGraph is a study that aims to overcome the problems in existing structural pruning that arise because group-level pruning does not take the whole network schema into account.
    https://until.blog/@namgyu-youn/-paper--deepgraph---towards-any-structural-pruning-rdre2g47
    Presentation (Gamma): https://gamma.app/docs/DepGraph-Towards-Any-Structural-Pruning-ijgzt91octaix88

    Main Concept : Dependency

    Inter-layer Dependency


    Inter-layer dependency occurs when one layer's output is physically connected to another layer's input (e.g., in Conv1 → Conv2, the output of Conv1 feeds the input of Conv2, so pruning Conv1's output channels forces pruning the same input channels of Conv2).
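To make this concrete, here is a minimal pure-Python sketch (no PyTorch). It models two 1x1 convolutions as weight matrices `W[out][in]`; the helpers `prune_out_channels`, `prune_in_channels` and `conv1x1` are hypothetical illustrations, not Torch-Pruning's functions of similar name. It shows that removing output channels of Conv1 only stays valid if the same indices are removed from Conv2's input channels.

```python
def prune_out_channels(weight, idxs):
    """Remove the rows (output channels) listed in idxs."""
    drop = set(idxs)
    return [row for o, row in enumerate(weight) if o not in drop]


def prune_in_channels(weight, idxs):
    """Remove the columns (input channels) listed in idxs."""
    drop = set(idxs)
    return [[w for i, w in enumerate(row) if i not in drop] for row in weight]


def conv1x1(weight, x):
    """1x1 convolution on a single pixel: y[o] = sum_i W[o][i] * x[i]."""
    for row in weight:
        if len(row) != len(x):
            raise ValueError(f"expected {len(row)} input channels, got {len(x)}")
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


conv1 = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]  # Conv1: 3 -> 4 channels
conv2 = [[1, 2, 3, 4], [4, 3, 2, 1]]                  # Conv2: 4 -> 2 channels
x = [1.0, 2.0, 3.0]

idxs = [1, 3]                               # channels pruned between Conv1 and Conv2
conv1_p = prune_out_channels(conv1, idxs)   # Conv1: 3 -> 2 channels
conv2_p = prune_in_channels(conv2, idxs)    # Conv2 must follow: 2 -> 2 channels

y = conv1x1(conv2_p, conv1x1(conv1_p, x))
print(y)  # -> [10.0, 10.0]
```

The result equals running the unpruned Conv2 on Conv1's output with channels 1 and 3 zeroed. Pruning only Conv1 (`conv1x1(conv2, conv1x1(conv1_p, x))`) raises `ValueError`, because Conv2 still expects 4 input channels.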

    Intra-layer Dependency

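    • Note: This is caused by a coupled structure inside the layer, where the layer ties its input and output channels together (e.g., an element-wise Mul requires its operands and its output to share the same shape).

    • Therefore, such a coupled structure must be pruned at the structure level: removing a channel on the input side forces removing the same channel on the output side.

    cf. In a regular (non-grouped) convolution, input and output channels can be pruned independently, so there is no intra-layer dependency.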
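The distinction can be summarized as a small lookup, sketched below in pure Python. The op-type strings and `coupled_input_channels` are hypothetical illustrations of the rule above, not Torch-Pruning code: for ops whose output channel c is computed only from input channel c, pruning the output forces pruning the same input indices; for a channel-mixing op such as a regular convolution, it does not.

```python
# Ops whose output channel c depends only on input channel c
# (intra-layer dependency: input and output channels are coupled one-to-one;
# for element-wise ops this applies to every operand).
CHANNEL_WISE_OPS = {"bn", "relu", "depthwise_conv", "add", "mul"}

# Ops where each output channel mixes all input channels (no intra-layer dependency).
CHANNEL_MIXING_OPS = {"conv", "linear"}


def coupled_input_channels(op_type, pruned_out_idxs):
    """Input-channel indices that must be pruned together with the given
    output-channel indices of a single layer."""
    if op_type in CHANNEL_WISE_OPS:
        return sorted(set(pruned_out_idxs))
    if op_type in CHANNEL_MIXING_OPS:
        return []
    raise ValueError(f"unknown op type: {op_type!r}")


print(coupled_input_channels("bn", [3, 0]))    # -> [0, 3]
print(coupled_input_channels("conv", [3, 0]))  # -> []
```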

    Configuration

    The mapping functions (module2type and type2class) configure the operation type stored in each Node (See also : Basic Concepts - Node), as follows:

    from enum import IntEnum
    import torch.nn as nn
    
    # Mapping list
    TORCH_CONV = nn.modules.conv._ConvNd
    TORCH_BATCHNORM = nn.modules.batchnorm._BatchNorm
    # ... (other module classes omitted)
    
    # Index mapping
    class OPTYPE(IntEnum):
        CONV = 0
        BN = 1
        # ... (other op types omitted)
        DEPTHWISE_CONV = 4
    
    # Configure the module's (layer's) type
    def module2type(module):
        if isinstance(module, TORCH_CONV):
            # One group per output channel (and more than one channel) => depthwise conv
            if module.groups == module.out_channels and module.out_channels > 1:
                return OPTYPE.DEPTHWISE_CONV
            else:
                return OPTYPE.CONV
        elif isinstance(module, TORCH_BATCHNORM):
            return OPTYPE.BN
        # ... (other module types omitted)
    
    # Map an op type back to its module class
    def type2class(op_type):
        if op_type == OPTYPE.CONV or op_type == OPTYPE.DEPTHWISE_CONV:
            return TORCH_CONV
        elif op_type == OPTYPE.BN:
            return TORCH_BATCHNORM
    • After mapping, the resulting operation type is stored in the Node class (as Node.type).
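The depthwise check in module2type can be exercised without PyTorch. In the sketch below, a hypothetical `FakeConv` dataclass stands in for nn.Conv2d (only `groups` and `out_channels` matter to the rule), and `classify_conv` restates the same condition, returning a string instead of an OPTYPE member.

```python
from dataclasses import dataclass


@dataclass
class FakeConv:
    """Hypothetical stand-in for nn.Conv2d with only the fields the rule reads."""
    in_channels: int
    out_channels: int
    groups: int = 1


def classify_conv(module):
    """Same condition as module2type: one group per output channel
    (and more than one channel) means a depthwise convolution."""
    if module.groups == module.out_channels and module.out_channels > 1:
        return "DEPTHWISE_CONV"
    return "CONV"


print(classify_conv(FakeConv(32, 32, groups=32)))  # -> DEPTHWISE_CONV
print(classify_conv(FakeConv(32, 64)))             # -> CONV
print(classify_conv(FakeConv(1, 1, groups=1)))     # -> CONV (a single channel is not treated as depthwise)
```

The distinction matters for pruning: in the common case in_channels == out_channels == groups, output channel c of a depthwise conv reads only input channel c, so it carries an intra-layer dependency like BN, whereas a regular conv does not.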

    Case 01 : Batch Normalization

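    • BN normalizes each channel independently with its own per-channel parameters (weight, bias, running mean and running variance; the statistics are computed over the mini-batch), so its input and output channels are coupled one-to-one. Therefore, BN must be pruned at the channel level: removing channel c removes it from the input, the output, and every per-channel parameter.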
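A minimal pure-Python sketch of this, assuming inference-mode BN on a single pixel; `bn_forward` and `prune_bn` are hypothetical helpers, not PyTorch or Torch-Pruning API. Pruning channel 1 deletes entry 1 from every per-channel vector, and the pruned layer applied to the pruned input reproduces the surviving channels of the original output exactly.

```python
import math


def bn_forward(bn, x, eps=1e-5):
    """Inference-mode BatchNorm on one pixel; every term is indexed by channel c."""
    return [
        (x[c] - bn["running_mean"][c]) / math.sqrt(bn["running_var"][c] + eps)
        * bn["weight"][c] + bn["bias"][c]
        for c in range(len(x))
    ]


def prune_bn(bn, idxs):
    """Delete the channel entries idxs from every per-channel vector of the BN state."""
    drop = set(idxs)
    return {name: [v for c, v in enumerate(vec) if c not in drop]
            for name, vec in bn.items()}


bn = {
    "weight": [1.0, 2.0, 3.0],
    "bias": [0.0, 0.5, 1.0],
    "running_mean": [0.0, 1.0, 2.0],
    "running_var": [1.0, 1.0, 4.0],
}
x = [1.0, 2.0, 4.0]

full = bn_forward(bn, x)
pruned = prune_bn(bn, [1])              # drop channel 1 from every parameter
y = bn_forward(pruned, [x[0], x[2]])    # the input must lose channel 1 as well
print(y == [full[0], full[2]])  # -> True
```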

    Case 02 : Element-wise Operations

    When tensors are combined by an element-wise operation (e.g., Add, Sub, Mul), their shapes are tied together: every operand and the output must have the same shape.

    Suppose a tensor has shape (N, C, H, W): batch N, channels C, height H, width W.

    To prune the k-th channel, we must remove the slice (:, k, :, :) from every tensor that shares this intra-layer dependency.
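Below is a small pure-Python sketch of this rule, using nested lists as NCHW tensors; `make_tensor`, `shape`, `prune_channel` and `elementwise_add` are hypothetical helpers written for illustration. Removing (:, k, :, :) from both operands keeps the Add valid.

```python
def make_tensor(n, c, h, w):
    """Nested-list tensor of shape (N, C, H, W) whose values equal the channel index."""
    return [[[[float(ch)] * w for _ in range(h)] for ch in range(c)] for _ in range(n)]


def shape(t):
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0] if t else None
    return tuple(dims)


def prune_channel(t, k):
    """Remove the slice (:, k, :, :) from an NCHW nested-list tensor."""
    return [[plane for ch, plane in enumerate(sample) if ch != k] for sample in t]


def elementwise_add(a, b):
    """Element-wise Add: the operands must have exactly the same shape."""
    if shape(a) != shape(b):
        raise ValueError(f"shape mismatch: {shape(a)} vs {shape(b)}")
    if isinstance(a, list):
        return [elementwise_add(x, y) for x, y in zip(a, b)]
    return a + b


a = make_tensor(1, 3, 2, 2)
b = make_tensor(1, 3, 2, 2)
k = 1
out = elementwise_add(prune_channel(a, k), prune_channel(b, k))  # both operands pruned
print(shape(out))  # -> (1, 2, 2, 2)
```

Pruning only `a` makes the call raise `ValueError`, which is why every tensor joined by the element-wise op has to lose the same channel.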


    Main Concept : Grouping

    class Group(object):
        def __init__(self):
            self._group = list()  # Pair list : (Dependency, Indices)
    
        def add_dep(self, dep, idxs):
            """Add a dependency together with its pruning indices."""
    
        def add_and_merge(self, dep, idxs):
            """Merge the indices into the existing entry if this dependency was already added via add_dep."""

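    Dependencies are added to a Group through these two methods: add_dep appends a new (dependency, indices) pair, and add_and_merge merges the indices into the existing entry when the same dependency is already in the group.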
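A minimal sketch of these two methods follows. `ToyGroup` is a simplified, hypothetical stand-in rather than the library's code: dependencies are plain `(source, target)` tuples, and `add_and_merge` takes the union of indices when the same dependency is already present, following the description above.

```python
class ToyGroup:
    """Simplified stand-in for Group: a list of (dependency, indices) pairs."""

    def __init__(self):
        self._group = []  # list of (dep, idxs) pairs

    def add_dep(self, dep, idxs):
        """Append a new (dependency, indices) pair."""
        self._group.append((dep, list(idxs)))

    def add_and_merge(self, dep, idxs):
        """Merge idxs into the existing entry for dep, or add a new entry."""
        for i, (existing_dep, existing_idxs) in enumerate(self._group):
            if existing_dep == dep:
                self._group[i] = (existing_dep, sorted(set(existing_idxs) | set(idxs)))
                return
        self.add_dep(dep, idxs)

    def items(self):
        return list(self._group)


g = ToyGroup()
g.add_dep(("Conv4", "BN5"), [0, 1])
g.add_and_merge(("Conv4", "BN5"), [1, 2])   # same dependency: indices merged
g.add_and_merge(("BN5", "ReLU6"), [0, 1])   # new dependency: appended
print(g.items())
# -> [(('Conv4', 'BN5'), [0, 1, 2]), (('BN5', 'ReLU6'), [0, 1])]
```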

    Example : ResNet (figure 3)

    Let's suppose the architecture is ResNet (Figure 3). Then the utility function group.details() prints the pruning group as follows:

    > group.details()
    --------------------------------
              Pruning Group
    --------------------------------
    [0] Conv4 → Conv4, idxs (2) = [0, 1]   | Pruning start point (no dependency)
    [1] Conv4 → BN5,   idxs (2) = [0, 1]   | Inter
    [2] BN5 → BN5,     idxs (2) = [0, 1]   | Intra : channel-level pruning
    [3] BN5 → ReLU6,   idxs (2) = [0, 1]   | Inter
    [4] ReLU6 → ReLU6, idxs (2) = [0, 1]   | Intra : channel-level pruning
    [5] ReLU6 → Add7,  idxs (2) = [0, 1]   | Inter
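The idea behind this listing can be sketched as a breadth-first traversal. The toy graph is an assumption (a straight chain with a hypothetical Conv8 after the Add and no skip branch), and the code illustrates the idea rather than Torch-Pruning's algorithm: starting from Conv4's output channels, every inter-layer edge is followed, and coupled layers (BN, ReLU, Add) pass the pruning on to their own outputs. The first six entries match the group above.

```python
from collections import deque

# Hypothetical toy graph: each key's output feeds the listed layers' inputs.
EDGES = {
    "Conv4": ["BN5"],
    "BN5": ["ReLU6"],
    "ReLU6": ["Add7"],
    "Add7": ["Conv8"],
    "Conv8": [],
}
# Layers whose input and output channels are coupled (intra-layer dependency).
COUPLED = {"BN5", "ReLU6", "Add7"}


def build_group(root, idxs):
    """Collect every (source, target, idxs) step forced by pruning root's output channels."""
    group = [(root, root, idxs)]  # pruning start point
    queue, visited = deque([root]), {root}
    while queue:
        node = queue.popleft()
        for nxt in EDGES[node]:
            group.append((node, nxt, idxs))          # inter-layer: node output -> nxt input
            if nxt in COUPLED and nxt not in visited:
                visited.add(nxt)
                group.append((nxt, nxt, idxs))       # intra-layer: nxt input -> nxt output
                queue.append(nxt)                    # nxt's output is pruned too, keep going
            # a non-coupled layer (e.g. Conv8) only loses input channels; propagation stops
    return group


for i, (src, dst, ids) in enumerate(build_group("Conv4", [0, 1])):
    print(f"[{i}] {src} -> {dst}, idxs={ids}")
```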

    Basic Concepts (Graph and Dependency)

    DependencyGraph

    • Inter-layer dependency : dependency caused by the **network** connection (the output of one layer feeds the input of the next)

    • Intra-layer dependency : dependency caused by a **shared** pruning scheme between the input and output of the same layer (e.g., BN), even though they are not connected by an edge

    Grouped Parameters

    class Group(object):
        def __init__(self):
            self._group = list()  # Pair based list : (Dependency, Index)
            self._DG = None      # DepGraph
        
        def prune(self, idxs=None, record_history=True):
            """Group-level pruning"""
    • Parameters in the same group share a common dependency, so they must be pruned at the same time.
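A toy sketch of "pruned at the same time": each group item pairs a pruning function with a layer and indices, and one pass applies all of them, so the coupled channel sets stay consistent. The layer dicts and the helpers `make_layer`, `prune_out_channels`, `prune_in_channels` and `prune_group` are hypothetical and only mimic the idea behind Group.prune.

```python
def make_layer(n_in, n_out):
    """Toy layer: records which channel ids survive on its input and output side."""
    return {"in": list(range(n_in)), "out": list(range(n_out))}


def prune_out_channels(layer, idxs):
    drop = set(idxs)
    layer["out"] = [c for c in layer["out"] if c not in drop]


def prune_in_channels(layer, idxs):
    drop = set(idxs)
    layer["in"] = [c for c in layer["in"] if c not in drop]


def prune_group(group):
    """Apply every (handler, layer, idxs) item of the group in one pass."""
    for handler, layer, idxs in group:
        handler(layer, idxs)


conv4, bn5, conv8 = make_layer(3, 4), make_layer(4, 4), make_layer(4, 8)
group = [
    (prune_out_channels, conv4, [0, 1]),  # pruning start point
    (prune_in_channels, bn5, [0, 1]),     # inter-layer: Conv4 -> BN5
    (prune_out_channels, bn5, [0, 1]),    # intra-layer: BN5 input -> BN5 output
    (prune_in_channels, conv8, [0, 1]),   # inter-layer: BN5 -> Conv8
]
prune_group(group)
print(conv4["out"], bn5["in"], bn5["out"], conv8["in"])  # -> [2, 3] [2, 3] [2, 3] [2, 3]
```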

    Node

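    import torch.nn as nn
    from torch_pruning import ops  # Torch-Pruning's op-type helpers (OPTYPE, module2type)
    
    class Node(object):  # (excerpt; other attributes omitted)
        def __init__(self, module: nn.Module, grad_fn, name: str = None):
            self.inputs, self.outputs = [], []    # Network decomposition: neighbouring input/output nodes
            self._name = name                     # Name of the module in the model
            self.type = ops.module2type(module)   # Operation type of the node (e.g., OPTYPE.CONV, OPTYPE.BN)
    
            # For Dependency Modeling
            self.dependencies = []  # List of Dependency objects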
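To see how inputs/outputs and dependencies fit together, here is a framework-free sketch. `ToyNode`, `connect`, and the op-type strings are hypothetical; in the library, nodes come from tracing the model. Each traced edge becomes an inter-layer dependency candidate recorded on the source node.

```python
class ToyNode:
    """Hypothetical, framework-free stand-in for Node: name, op type, neighbours."""

    def __init__(self, name, op_type):
        self.name = name
        self.type = op_type
        self.inputs, self.outputs = [], []  # neighbouring nodes
        self.dependencies = []              # filled in by dependency modeling


def connect(src, dst):
    """Record that src's output feeds dst's input (a traced edge)."""
    src.outputs.append(dst)
    dst.inputs.append(src)


conv = ToyNode("conv1", "CONV")
bn = ToyNode("bn1", "BN")
relu = ToyNode("relu1", "RELU")
connect(conv, bn)
connect(bn, relu)

# Every traced edge yields an inter-layer dependency candidate (source, target).
for node in (conv, bn, relu):
    for out in node.outputs:
        node.dependencies.append((node.name, out.name))
print(conv.dependencies, bn.dependencies)  # -> [('conv1', 'bn1')] [('bn1', 'relu1')]
```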

    Dependency

    class Dependency(object):
        def __init__(self, trigger, handler, source: Node, target: Node):
            self.trigger, self.handler = trigger, handler  # Trigger & Handler pattern
            self.source, self.target = source, target      # Start & end node
    
            # index_mapping is used to map the indices of the source node to the target node.
            # Each dependency has two index_mapping functions to handle cascaded concat & split operations,
            # e.g., split -> concat:
            # we first map the indices to the split tensor with index_mapping[0],
            # then map the split tensor to the concatenated tensor with index_mapping[1].
            # current coordinates --index_mapping[0]--> standard coordinates --index_mapping[1]--> target coordinates
            self.index_mapping = [INDEX_MAPPING_PLACEHOLDER, INDEX_MAPPING_PLACEHOLDER]
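    • Trigger & Handler pattern : the trigger is the pruning operation applied to the source node, and the handler is the pruning operation it forces on the target node.

    • e.g., if a dependency f- → f+ exists, then applying the trigger on f- calls handler(f+).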
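The trigger/handler pattern and the two-stage index mapping can be sketched together. Everything below is hypothetical (`ToyDependency`, `split_mapping`, `concat_mapping`, `propagate`, and the layers convA/convB/convC), a simplified stand-in for the library's Dependency: a dependency fires only when its source receives its trigger, and the pruned indices pass through index_mapping[0] (into the split chunk) and then index_mapping[1] (into the concatenated tensor) before the handler runs on the target.

```python
def _identity(idxs):
    return list(idxs)


class ToyDependency:
    """Simplified stand-in for Dependency: trigger/handler plus two index mappings."""

    def __init__(self, trigger, handler, source, target, index_mapping=None):
        self.trigger, self.handler = trigger, handler
        self.source, self.target = source, target
        self.index_mapping = index_mapping or [_identity, _identity]

    def __call__(self, idxs):
        # current coords --index_mapping[0]--> standard coords --index_mapping[1]--> target coords
        mapped = self.index_mapping[1](self.index_mapping[0](idxs))
        return self.handler(self.target, mapped)


def prune_out_channels(layer, idxs):  # hypothetical trigger
    return f"prune_out_channels({layer}, {idxs})"


def prune_in_channels(layer, idxs):   # hypothetical handler
    return f"prune_in_channels({layer}, {idxs})"


def split_mapping(offset, length):
    """Coords of x -> coords inside the chunk x[offset:offset + length]."""
    return lambda idxs: [i - offset for i in idxs if offset <= i < offset + length]


def concat_mapping(offset):
    """Coords inside a chunk -> coords inside the concatenated tensor."""
    return lambda idxs: [i + offset for i in idxs]


def propagate(deps, source, trigger, idxs):
    """Fire every dependency whose source and trigger match the pruning just applied."""
    return [dep(idxs) for dep in deps if dep.source == source and dep.trigger is trigger]


# Hypothetical graph: convA outputs 8 channels, split into x[:4] and x[4:];
# cat([y (3 channels from convB), x[4:]]) feeds convC, which has 7 input channels.
dep = ToyDependency(
    trigger=prune_out_channels,
    handler=prune_in_channels,
    source="convA",
    target="convC",
    index_mapping=[split_mapping(4, 4), concat_mapping(3)],
)
print(propagate([dep], "convA", prune_out_channels, [1, 5, 6]))
# -> ['prune_in_channels(convC, [4, 5])']
```

Index 1 lands in the other chunk x[:4], so this dependency ignores it; 5 and 6 become 1 and 2 inside x[4:], and then 4 and 5 after the 3 channels of y in the concatenation.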






