Integer Partition with Distinct Parts

Number of ways to write an integer as the sum of a decreasing sequence, e.g., 5=4+1=3+2

Question

In how many ways can you write an integer n as the sum of a decreasing sequence of positive integers (excluding n + 0, i.e. n by itself)?

Let’s look at an example from Wikipedia:

Alternatively, we could count partitions in which no number occurs more than once. Such a partition is called a partition with distinct parts. If we count the partitions of 8 with distinct parts, we also obtain 6:

  • 8
  • 7 + 1
  • 6 + 2
  • 5 + 3
  • 5 + 2 + 1
  • 4 + 3 + 1

The problem boils down to finding the number of partitions with distinct parts and subtracting 1 from it. So I’ll focus on the core part.

Solution

Doing it by hand is easy. We choose the first element, e.g., 5, figure out 8-5=3, and try to split 3 into smaller parts. For human beings, checking “3 is less than 5” on paper happens in an instant. But we need a way to tell the computer what the current max is, so it knows the possible choices for the next number.

Recursion

Visualize the result

Before we directly count the total number of sequences, it helps to visualize the result. The code below is based on an idea from Stack Overflow:

Translate partition(n, maxi) into English: Partition the number n into a decreasing sequence where the first element (the largest one) is no bigger than maxi.

min(n-i, i-1) ensures that the sequence is decreasing and that the next part is no bigger than what remains. To see that, let $M = \min(n-i,\ i-1)$, so $M \leq i-1$ and $M \leq n-i$: the next part is strictly smaller than i, and it never exceeds the remainder n-i.

This also avoids extra calculations because no part can be bigger than what we have left. For example, partition(5, 5) = partition(5, 14), so we don’t need to run the loop from 6 to 14.

from typing import List

def partition(n: int, maxi: int) -> List[List[int]]:
    # Nothing left to split: the sequence is complete.
    if n == 0: return [[]]
    results = []
    # Try every first part i from maxi down to 1.
    # If maxi <= 0, the loop is empty and there is no partition.
    for i in range(maxi, 0, -1):
        # n - i is what remains after taking away i from n.
        subresults = partition(n - i, min(n - i, i - 1))
        results.extend([[i] + sub for sub in subresults])
    return results

> partition(5, 5)
[[5], [4, 1], [3, 2]]
 
> partition(10, 5)
[[5, 4, 1], [5, 3, 2], [4, 3, 2, 1]]
 
> partition(10, 10)
[[10], [9, 1], [8, 2], [7, 3], [7, 2, 1], [6, 4], [6, 3, 1], [5, 4, 1], [5, 3, 2], [4, 3, 2, 1]]

Let’s count the total number

The base case is n == 0. The recursion reaches count_p(0, 0) exactly when the chosen part i equals the remaining n, so the sequence is complete. That counts as one partition no matter what the max is.
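The counting rule can also be written as a recurrence. This is only a sketch: $c(n, m)$ is an assumed shorthand for count_p(n, maxi), not notation from the original, and an empty sum (when $m \leq 0$) is taken to be 0.

```latex
c(0, m) = 1, \qquad
c(n, m) = \sum_{i=1}^{m} c\bigl(n - i,\ \min(n - i,\ i - 1)\bigr) \quad (n \neq 0)

% Worked example for n = 5, taking i = 5, 4, 3, 2, 1 in turn:
c(5, 5) = c(0, 0) + c(1, 1) + c(2, 2) + c(3, 1) + c(4, 0)
        = 1 + 1 + 1 + 0 + 0 = 3
```

The three non-zero terms correspond to 5, 4 + 1 and 3 + 2. The last two terms vanish because the remainder cannot be built from parts smaller than the one just chosen.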

def count_p(n, maxi):
  if n == 0: return 1
  result = 0
  for i in range(maxi, 0, -1):
    result += count_p(n-i, min(n-i, i-1))
  return result

> count_p(10, 10)
10
> count_p(5, 5)
3
> count_p(200, 200)
... waiting ...

It works, but it’s super inefficient. To see where recalculations happen, I modified the function as below:

from collections import Counter
 
def count_p_recalc(n, maxi, pool):
    if n == 0: return pool
    for i in range(maxi, 0, -1):
        pool[(n, i)] += 1
        pool = count_p_recalc(n - i, min(n - i, i - 1), pool)
    return pool
 
print(count_p_recalc(20, 20, Counter()).most_common(5))
 
# Result format: ((n, i), number of times that pair was visited)
# [((1, 1), 29), ((2, 1), 25), ((3, 1), 21), ((4, 1), 17), ((2, 2), 15)]

If we store these results somewhere, we can look them up instead of recalculating them.
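One quick way to try this idea is to cache the recursive function itself. The sketch below assumes Python's functools.lru_cache as the store. The name count_p_memo is made up for illustration, and this is memoization on top of the recursion above, not the table-based approach.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def count_p_memo(n: int, maxi: int) -> int:
    """Count partitions of n into distinct parts, each no bigger than maxi."""
    if n == 0:
        return 1  # nothing left to split: one complete sequence
    total = 0
    for i in range(maxi, 0, -1):
        # After taking i, the next part must be smaller than i
        # and cannot exceed the remainder n - i.
        total += count_p_memo(n - i, min(n - i, i - 1))
    return total


print(count_p_memo(5, 5))      # 3
print(count_p_memo(10, 10))    # 10
print(count_p_memo(8, 8) - 1)  # 5, excluding the trivial 8 = 8
```

Each (n, maxi) pair is now computed only once, so count_p_memo(200, 200) should return almost immediately instead of hanging.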

Dynamic Programming

def P(n):
 
    size = n + 1
    # initialize the table
    p = [[0 for _ in range(size)] for _ in range(size)]
 
    p[0][0] = 1
    for i in range(size):
        for j in range(1, size):
            # The same idea: we don't care about more than what we have
            if j <= i:
                p[i][j] = p[i][j - 1] + p[i - j][min(i - j, j - 1)]
 
    return p[n][n]

Let’s decipher the line that fills the table, p[i][j] = p[i][j - 1] + p[i - j][min(i - j, j - 1)]:

p[i][j]: The number of ways to partition the integer i into a decreasing sequence whose first (largest) element is less than or equal to j. Here we are accumulating the result. Say we want to partition 8. After selecting 5 as the first element, we want to find all decreasing sequences starting with 4 or less that sum to 8-5=3. As long as the first element is less than 5, the sequence is valid.

p[i][j-1] +: Carry over the partitions of i whose largest element is at most j-1, which is the accumulation explained above.

p[i-j][min(i-j, j-1)]: Count all valid subsequences of i-j (what’s left) after taking j as the first element. This corresponds to our recursion partition(n-i, min(n-i, i-1)).

Finally, we can find our result at p[n][n]: partition the integer n with the starting element less than or equal to n. This is exactly what we’re looking for.

If you want to exclude the trivial n = n + 0, just return p[n][n] - 1.

Here’s the table from P(8). The result is at p[8][8] = 6, as we expected.

[
 [1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 1, 2, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 2, 0, 0, 0, 0],
 [0, 0, 0, 1, 2, 3, 0, 0, 0],
 [0, 0, 0, 1, 2, 3, 4, 0, 0],
 [0, 0, 0, 0, 2, 3, 4, 5, 0],
 [0, 0, 0, 0, 1, 3, 4, 5, 6]
]
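To reproduce a table like this, one option is a variant that returns the whole table instead of only p[n][n]. This is a minimal sketch: build_table is a hypothetical helper name, and it uses the same update rule as P(n).

```python
from typing import List


def build_table(n: int) -> List[List[int]]:
    """Return the full table p, where p[i][j] counts partitions of i
    into distinct parts with the largest part <= j (for j <= i)."""
    size = n + 1
    p = [[0] * size for _ in range(size)]
    p[0][0] = 1  # the empty partition of 0
    for i in range(1, size):
        # Only columns j <= i are filled, as in P(n).
        for j in range(1, i + 1):
            p[i][j] = p[i][j - 1] + p[i - j][min(i - j, j - 1)]
    return p


for row in build_table(8):
    print(row)
```

Entries with j > i stay 0. That is harmless because the update only ever reads p[i - j][min(i - j, j - 1)], whose column never exceeds its row.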

Find me on Twitter or write me an email