Number of Submatrices that Sum to Target

Explanations with pictures

This article is a complete analysis of LeetCode problem 1074. Number of Submatrices That Sum to Target.

Brute force

Consider every possible pair of [a, b] (the top-left position) and [c, d] (the bottom-right position), and check the sum of the submatrix each pair defines.
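The brute-force idea can be sketched as follows. This is only an illustrative sketch: the function name `count_brute_force` is hypothetical, and the matrix is assumed to be non-empty and rectangular.

```python
from typing import List


def count_brute_force(matrix: List[List[int]], target: int) -> int:
    """Count submatrices summing to target by trying every corner pair."""
    rows, cols = len(matrix), len(matrix[0])
    count = 0
    for a in range(rows):                  # top row
        for b in range(cols):              # left column
            for c in range(a, rows):       # bottom row
                for d in range(b, cols):   # right column
                    # Re-sum every entry of the submatrix from scratch.
                    total = sum(
                        matrix[r][q]
                        for r in range(a, c + 1)
                        for q in range(b, d + 1)
                    )
                    if total == target:
                        count += 1
    return count
```

For example, `count_brute_force([[1, -1], [-1, 1]], 0)` returns 5: two 1x2 strips, two 2x1 strips, and the whole matrix.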

Time complexity

Without loss of generality, we assume the given matrix is a square matrix of size $n \times n$.

We have $n$ choices for a and, for each a, $n$ choices for b, which gives $n^2$ choices of [a, b].

For each [a, b], there are at most $n \times n$ choices of [c, d], so there are $n^2 \cdot n^2 = n^4$ pairs in total.

However, for each pair we need to sum up all entries in the submatrix, which costs up to $O(n^2)$ and gives $O(n^6)$ overall. This is unacceptable for large inputs.

Optimize the brute force strategy

Do we need to sum up all entries in a submatrix every time? Can we reuse some computations?

Observation from a simple case: if the “height” of the matrix is 1, it is essentially an array. Let sum[i] denote the sum of the subarray from index 0 to i. If we want to know sum[3], we only need to add the element at index 3 to sum[2]. So if we store each sum[i], every new prefix sum costs only one addition, which is $O(1)$.
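A minimal sketch of the 1D case, assuming 0-based indices as in the text. The helper names `prefix_sums` and `range_sum` are made up for illustration.

```python
from typing import List


def prefix_sums(nums: List[int]) -> List[int]:
    """Return sums where sums[i] is the sum of nums[0..i] inclusive."""
    sums = []
    running = 0
    for x in nums:
        running += x          # one addition per new prefix sum
        sums.append(running)
    return sums


def range_sum(sums: List[int], i: int, j: int) -> int:
    """Sum of the original array from index i to j, using stored prefix sums."""
    return sums[j] - (sums[i - 1] if i > 0 else 0)
```

With `nums = [3, 1, 4, 1, 5]`, `prefix_sums(nums)` gives `[3, 4, 8, 9, 14]`, and `range_sum` on that list with `i = 1, j = 3` gives 6.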

We can extend the idea to 2D to reuse computations:


To compute the 2D prefix sum ps[i][j], the sum of all entries in the rectangle from [0, 0] to [i, j], we use the following loop:

ps = [[0] * colsize for _ in range(rowsize)]
for i in range(rowsize):
    for j in range(colsize):
        if i == 0 and j == 0:
            ps[i][j] = matrix[i][j]
        elif i == 0:  # first row
            ps[i][j] = ps[i][j-1] + matrix[i][j]
        elif j == 0:  # first col
            ps[i][j] = ps[i-1][j] + matrix[i][j]
        else:
            # add the two overlapping rectangles, then remove the doubly counted part
            ps[i][j] = ps[i-1][j] + ps[i][j-1] - ps[i-1][j-1] + matrix[i][j]

Time complexity

Now we can get the sum of any submatrix in $O(1)$ using the prefix-sum formula, but we still have to enumerate $O(n^4)$ pairs of “top-left” and “bottom-right” corners, and this approach cannot reduce that number. The total cost is $O(n^4)$.
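The $O(1)$ lookup is not spelled out in the text, so here is one way it could look, assuming ps[i][j] holds the sum of the rectangle from [0, 0] to [i, j]. The function names are hypothetical, and out-of-range terms are treated as 0.

```python
from typing import List


def build_prefix(matrix: List[List[int]]) -> List[List[int]]:
    """ps[i][j] = sum of all entries in the rectangle [0, 0]..[i, j]."""
    rows, cols = len(matrix), len(matrix[0])
    ps = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            up = ps[i - 1][j] if i > 0 else 0
            left = ps[i][j - 1] if j > 0 else 0
            diag = ps[i - 1][j - 1] if i > 0 and j > 0 else 0
            ps[i][j] = up + left - diag + matrix[i][j]
    return ps


def submatrix_sum(ps: List[List[int]], a: int, b: int, c: int, d: int) -> int:
    """Sum of the submatrix with top-left [a, b] and bottom-right [c, d]."""
    total = ps[c][d]
    if a > 0:
        total -= ps[a - 1][d]      # strip above the submatrix
    if b > 0:
        total -= ps[c][b - 1]      # strip left of the submatrix
    if a > 0 and b > 0:
        total += ps[a - 1][b - 1]  # corner that was subtracted twice
    return total


def count_with_prefix(matrix: List[List[int]], target: int) -> int:
    """Enumerate all corner pairs, with an O(1) sum lookup for each."""
    ps = build_prefix(matrix)
    rows, cols = len(matrix), len(matrix[0])
    count = 0
    for a in range(rows):
        for b in range(cols):
            for c in range(a, rows):
                for d in range(b, cols):
                    if submatrix_sum(ps, a, b, c, d) == target:
                        count += 1
    return count
```

For `[[1, 2], [3, 4]]` the prefix table is `[[1, 3], [4, 10]]`, and the bottom row (a=1, b=0, c=1, d=1) sums to 10 - 3 = 7.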

“Compress” rows to 1D

Number of subarrays that sum to target

It’s always easier to start with a simpler version of the problem and extend the idea.

The 1D counterpart of this problem is “number of subarrays that sum to target”.

The idea comes from the discussion of 560. Subarray Sum Equals K. A single key observation reduces the cost to $O(n)$.

Let sum[i, j] denote the sum from index i to j.

We observe that any sum[i, j] can be written as the difference of two prefix sums:

$$
\begin{aligned}
\text{sum}[i, j] &= \text{sum}[0, j] - \text{sum}[0, i-1]\\
\text{Assume sum}[i, j] &= k, \text{ the given target}\\
k &= \text{sum}[0, j] - \text{sum}[0, i-1]\\
\text{sum}[0, i-1] &= \text{sum}[0, j] - k
\end{aligned}
$$


The key observation is: the number of earlier prefix sums equal to sum[0, j] - k is exactly the number of subarrays ending at index j that sum to k.
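To make the counting concrete, the sketch below records, for each index j, the running prefix sum and how many earlier prefix sums equal it minus k. The function name `trace_prefix_counts` and the sample input are made up for illustration.

```python
from collections import Counter
from typing import List, Tuple


def trace_prefix_counts(nums: List[int], k: int) -> List[Tuple[int, int, int]]:
    """Return (j, prefix sum up to j, subarrays ending at j that sum to k)."""
    counter = Counter({0: 1})   # the empty prefix has sum 0
    cur_sum = 0
    steps = []
    for j, x in enumerate(nums):
        cur_sum += x
        matches = counter[cur_sum - k]   # earlier prefixes equal to cur_sum - k
        steps.append((j, cur_sum, matches))
        counter[cur_sum] += 1
    return steps
```

For `nums = [3, 4, -4, 3]` and `k = 3`, the steps are `(0, 3, 1)`, `(1, 7, 0)`, `(2, 3, 1)`, `(3, 6, 2)`. At j = 3 the prefix sum 3 has already occurred twice, so two subarrays ending there sum to 3: `[4, -4, 3]` and `[3]`. The matches add up to 4 in total.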


from collections import Counter
from typing import List

class Solution:
    def subarraySum(self, nums: List[int], k: int) -> int:
        result, cur_sum = 0, 0
        # counter maps a prefix sum to its number of occurrences
        # Seeding with {0: 1} covers the case cur_sum == k, where cur_sum - k == 0
        counter = Counter({0: 1})
        for j in range(len(nums)):
            # add the new number to the existing sum
            # cur_sum stores the sum of subarray[0, j]
            cur_sum += nums[j]
            # increment the result by the number of earlier prefix sums equal to cur_sum - k
            result += counter[cur_sum - k]
            # store the current sum in counter
            counter[cur_sum] += 1
        return result

From 1D to 2D

If we can control one dimension, we will reduce the matrix problem to the simpler array problem.

We first compute the prefix sum for each row. Then we fix a left column i and a right column j, which “converts” every submatrix spanning exactly those columns into a subarray of per-row sums.
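As a small illustration of the “conversion”, the sketch below collapses a fixed pair of columns into a 1D array of per-row sums. It slices each row directly rather than using row prefix sums, purely for readability, and the name `compress_columns` is hypothetical.

```python
from typing import List


def compress_columns(matrix: List[List[int]], leftcol: int, rightcol: int) -> List[int]:
    """For each row, sum the entries in columns leftcol..rightcol inclusive."""
    return [sum(row[leftcol:rightcol + 1]) for row in matrix]
```

For `[[0, 1, 0], [1, 1, 1], [0, 1, 0]]` with columns 0 to 1, the result is `[1, 2, 1]`. Every contiguous run of this array corresponds to one submatrix spanning columns 0 to 1, so the 1D counting applies to it unchanged.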


from collections import Counter
from typing import List

class Solution:
    def numSubmatrixSumTarget(self, A: List[List[int]], target: int) -> int:
        result = 0
        rowsize, colsize = len(A), len(A[0])
        # get prefix sum for each row (this overwrites A in place)
        for row in A:
            for j in range(1, colsize):
                row[j] += row[j-1]

        # left and right col define the width of the submatrix
        for leftcol in range(colsize):
            for rightcol in range(leftcol, colsize):
                # run the 1D counting over the rows for this pair of columns
                counter = Counter({0: 1})
                cur_sum = 0
                for r in range(rowsize):
                    # add the sum of row r restricted to columns leftcol..rightcol
                    cur_sum += A[r][rightcol] - (A[r][leftcol-1] if leftcol > 0 else 0)
                    result += counter[cur_sum - target]
                    counter[cur_sum] += 1
        return result

Time complexity

  1. Get prefix sum for each row: $O(n^2)$
  2. For each pair of columns (i, j), we need $O(n)$ to count the matching submatrices. There are $O(n^2)$ pairs, giving $O(n^3)$

Total cost: $O(n^3)$

The same idea of rearranging the equation and counting earlier prefix values can be applied to many problems of the form “find the number of sub-patterns that satisfy a given condition”.
