
235. Partition to K Equal Sum Subsets

🔗 LeetCode Problem: 698. Partition to K Equal Sum Subsets
📊 Difficulty: Medium
🏷️ Topics: Dynamic Programming, Backtracking, Bit Manipulation, Memoization, Bitmask DP

Problem Statement

Given an integer array nums and an integer k, return true if it is possible to divide this array into k non-empty subsets whose sums are all equal.

Example 1:

Input: nums = [4,3,2,3,5,2,1], k = 4
Output: true
Explanation: It is possible to divide it into 4 subsets (5), (1,4), (2,3), (2,3) with equal sums.

Example 2:

Input: nums = [1,2,3,4], k = 3
Output: false
Explanation: It is impossible to divide it into 3 equal sum subsets.

Constraints:
- 1 <= k <= nums.length <= 16
- 1 <= nums[i] <= 10^4
- The frequency of each element is in the range [1, 4]


🔍 Let's Discover the Pattern Together

Start with Small Examples

Let's not jump to the solution. Let's understand what's really happening by trying small examples:

Example 1: nums = [4, 3, 2, 3, 5, 2, 1], k = 4
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Total sum = 4+3+2+3+5+2+1 = 20
Want: 4 equal subsets

If 4 subsets with equal sums:
  Each subset sum = 20/4 = 5

Now the question becomes:
  "Can we form 4 subsets, each with sum 5?"

Let's try:
  Subset 1: {5} → sum = 5 ✓
  Subset 2: {4, 1} → sum = 5 ✓
  Subset 3: {3, 2} → sum = 5 ✓
  Subset 4: {3, 2} → sum = 5 ✓

Yes! It's possible! ✓

Example 2: nums = [1, 2, 3, 4], k = 3
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Total sum = 1+2+3+4 = 10
Want: 3 equal subsets

Each subset sum = 10/3 = 3.33...

STOP! 🚫
  Sum is not divisible by k!
  Can't have equal integer sums!

Answer: false ✗

Example 3: nums = [2, 2, 2, 2, 2, 2], k = 3
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Total sum = 12
Each subset sum = 12/3 = 4

Can we form 3 subsets with sum 4?
  Subset 1: {2, 2} → sum = 4 ✓
  Subset 2: {2, 2} → sum = 4 ✓
  Subset 3: {2, 2} → sum = 4 ✓

Yes! ✓

Notice the Pattern

Key observations:

1. First check: Is total divisible by k?
   If NO → impossible!

2. Target sum = total / k
   Each subset must sum to this target

3. Problem transforms to:
   "Can we group all numbers into k groups,
    where each group sums to target?"

4. This is like filling k buckets!
   Each bucket needs to reach exactly 'target'
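The first two observations (plus the related "largest number must fit" check the later code applies) can be sketched as a cheap pre-check. The class TargetCheck, the method bucketTarget, and the -1 "ruled out" sentinel are hypothetical names chosen for this illustration, not part of the original solution.

```java
import java.util.Arrays;

class TargetCheck {
    // Returns the per-bucket target, or -1 if the cheap checks already rule the input out.
    // (The -1 sentinel is a convention of this sketch only.)
    static int bucketTarget(int[] nums, int k) {
        int total = Arrays.stream(nums).sum();
        if (total % k != 0) return -1;      // total must split evenly into k parts
        int target = total / k;             // every bucket must reach exactly this
        for (int num : nums) {
            if (num > target) return -1;    // a number bigger than target fits in no bucket
        }
        return target;
    }

    public static void main(String[] args) {
        System.out.println(bucketTarget(new int[]{4, 3, 2, 3, 5, 2, 1}, 4)); // 5
        System.out.println(bucketTarget(new int[]{1, 2, 3, 4}, 3));          // -1
        System.out.println(bucketTarget(new int[]{2, 2, 2, 2, 2, 2}, 3));    // 4
    }
}
```

Passing these checks is necessary but not sufficient: a valid target only tells us what each bucket must reach, not that the numbers can actually be grouped that way.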

The Bucket-Filling Intuition

Think of it as k buckets:

Buckets: [  ] [  ] [  ] [  ]  (k = 4)
Target:   5    5    5    5

Numbers to distribute: 4, 3, 2, 3, 5, 2, 1

For each number, pick a bucket!
  Number 4 → Bucket 0: [4, _, _]
  Number 3 → Bucket 1: [3, _, _]
  Number 2 → Bucket 0: [4, 2, _]  (now 6 > 5 ✗)
           → Bucket 2: [2, _, _]  (try this!)
  ...

Try all possibilities!
This is BACKTRACKING! 🔑

💡 The AHA Moment - It's a Backtracking Problem!

Connection to Previous Problems

Problem 228: Partition Equal Subset Sum (k=2)
  → Can we make 2 equal groups?
  → Used DP (subset sum)

Problem 233: Last Stone Weight II (k=2)
  → Minimize difference between 2 groups
  → Used DP (subset sum)

THIS Problem: k can be > 2!
  → The single-sum subset DP no longer fits
  → Need BACKTRACKING (try all assignments)

The jump from k=2 to k>2 is SIGNIFICANT! 🔥

Why Not DP?

For k=2:
  State: "Can we make sum X?"
  DP: dp[sum] = true/false
  Works because the 2nd group is just the complement of the 1st!

For k>2:
  Need to track MULTIPLE group sums simultaneously
  dp[sum1][sum2][sum3]...? Too many dimensions!

  Better approach: Try all assignments!
  → BACKTRACKING
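For contrast, here is a minimal sketch of the k = 2 subset-sum DP described above, showing why one boolean per reachable sum is enough when there are only two groups. The names TwoWayPartition and canPartitionTwo are hypothetical and used only for this illustration.

```java
class TwoWayPartition {
    // k = 2 only: an equal split exists iff some subset sums to total / 2,
    // because the other group is simply the complement.
    static boolean canPartitionTwo(int[] nums) {
        int total = 0;
        for (int num : nums) total += num;
        if (total % 2 != 0) return false;
        int half = total / 2;
        boolean[] dp = new boolean[half + 1]; // dp[s]: some subset of the numbers seen so far sums to s
        dp[0] = true;
        for (int num : nums) {
            for (int s = half; s >= num; s--) { // downward, so each number is used at most once
                dp[s] = dp[s] || dp[s - num];
            }
        }
        return dp[half];
    }

    public static void main(String[] args) {
        System.out.println(canPartitionTwo(new int[]{4, 3, 2, 1})); // true: {4,1} and {3,2}
        System.out.println(canPartitionTwo(new int[]{1, 2, 3, 5})); // false: total 11 is odd
    }
}
```

With three or more groups, "some subset reaches target" no longer implies the rest splits evenly, so a single dp[sum] table cannot answer the question on its own.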

🔴 Approach 1: Simple Backtracking (What We Discover First)

📐 Function Definition

Function Signature:

boolean recursive(int[] nums, int k, int index, int[] buckets)

What does this function compute?
- Input index: Current number we're trying to place (nums[index])
- Input buckets: Array of k buckets, buckets[i] = current sum in bucket i
- Input nums: The array of numbers
- Input k: Number of buckets
- Output: true if we can distribute the remaining numbers so that all buckets end up equal

What does the recursive call represent?
- For the current number nums[index], try placing it in EACH of the k buckets
- For each bucket i:
  - Place nums[index] in bucket i
  - Recursively try to place the remaining numbers (index+1)
  - If successful, return true
  - Otherwise, backtrack (remove it from bucket i and try the next bucket)
- If no bucket works, return false

The Recursive Structure:

Base case:
  If index == n (all numbers placed):
    Check if all buckets have the SAME sum
    return true if equal, false otherwise

For current number nums[index]:
  Try each bucket i from 0 to k-1:
    buckets[i] += nums[index]  // Place in bucket i

    if recursive(index+1, buckets):  // Try remaining numbers
      return true  // Found solution!

    buckets[i] -= nums[index]  // Backtrack

  return false  // No bucket worked

Why this works:

We systematically try all possible assignments:
  - Each number can go into any of k buckets
  - We explore all k^n combinations
  - When we've placed all numbers, check if buckets are equal
  - Backtrack when necessary

This is exhaustive search!

💡 Intuition

Think of placing items one by one:

nums = [4, 3, 2, 1], k = 2

Start with empty buckets: [0, 0]

Place 4:
  Try bucket 0: [4, 0]
    Place 3:
      Try bucket 0: [7, 0]
        Place 2:
          Try bucket 0: [9, 0]
            Place 1:
              Try bucket 0: [10, 0]
                All placed! Check: 10 ≠ 0 ✗
              Try bucket 1: [9, 1]
                All placed! Check: 9 ≠ 1 ✗
            Backtrack!
          Try bucket 1: [7, 2]
            ... (continue)
      Try bucket 1: [4, 3]
        ... (continue)

Tree of possibilities!
Explore until we find a solution or exhaust all options.

📝 Implementation (Simple Version - What We Discover First)

class Solution {
    public boolean canPartitionKSubsets(int[] nums, int k) {
        int total = 0;
        for (int num : nums) {
            total += num;
        }

        // Early check: divisible?
        if (total % k != 0) return false;

        int[] buckets = new int[k];
        return recursive(nums, k, 0, buckets);
    }

    private boolean recursive(int[] nums, int k, int index, int[] buckets) {
        // Base case: all numbers placed
        if (index == nums.length) {
            // Check if all buckets have same sum
            for (int i = 1; i < buckets.length; i++) {
                if (buckets[i] != buckets[i - 1]) {
                    return false;
                }
            }
            return true;
        }

        // Try placing nums[index] in each bucket
        for (int i = 0; i < k; i++) {
            buckets[i] += nums[index];  // Place

            if (recursive(nums, k, index + 1, buckets)) {
                return true;  // Found solution!
            }

            buckets[i] -= nums[index];  // Backtrack
        }

        return false;  // No bucket worked
    }
}

🔍 Detailed Dry Run: nums = [4, 3, 2, 1], k = 2

total = 10, target = 5
Initial buckets: [0, 0]

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
BACKTRACKING EXPLORATION (Showing key paths)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

recursive(index=0, buckets=[0,0], num=4):

  Try bucket 0:
    buckets = [4, 0]

    recursive(index=1, buckets=[4,0], num=3):

      Try bucket 0:
        buckets = [7, 0]

        recursive(index=2, buckets=[7,0], num=2):

          Try bucket 0:
            buckets = [9, 0]

            recursive(index=3, buckets=[9,0], num=1):

              Try bucket 0:
                buckets = [10, 0]

                recursive(index=4, buckets=[10,0]):
                  Base case! Check buckets:
                  10 ≠ 0 → return false ✗

                Backtrack: buckets = [9, 0]

              Try bucket 1:
                buckets = [9, 1]

                recursive(index=4, buckets=[9,1]):
                  Base case! Check buckets:
                  9 ≠ 1 → return false ✗

                Backtrack: buckets = [9, 0]

              return false ✗

            Backtrack: buckets = [7, 0]

          Try bucket 1:
            buckets = [7, 2]
            (continue exploration...)
            This path also fails ✗

          return false ✗

        Backtrack: buckets = [4, 0]

      Try bucket 1:
        buckets = [4, 3]

        recursive(index=2, buckets=[4,3], num=2):

          Try bucket 0:
            buckets = [6, 3]
            (6 > 5 already, so every completion fails,
             but this simple version still explores them all)
            → false ✗, backtrack: buckets = [4, 3]

          Try bucket 1:
            buckets = [4, 5]

            recursive(index=3, buckets=[4,5], num=1):

              Try bucket 0:
                buckets = [5, 5]

                recursive(index=4, buckets=[5,5]):
                  Base case! Check buckets:
                  5 == 5 → return true ✓

                return true ✓

              return true ✓

            return true ✓

          return true ✓

        return true ✓

      return true ✓

    return true ✓

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
RESULT: true ✓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Final partition:
  Bucket 0: [4, 1] = 5
  Bucket 1: [3, 2] = 5

Both equal! Success! ✓

Why This Is Slow

Time Complexity: O(k^n)
  - For each number: k choices
  - n numbers total
  - Total: k × k × k × ... (n times) = k^n

For nums.length = 16, k = 4:
  4^16 = 4,294,967,296 (over 4 billion!)
  WAY TOO SLOW! ❌
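To see the k^n blow-up concretely, here is a sketch of the simple recursion with a leaf counter added. On an input whose answer is false, nothing returns early, so every one of the k^n assignments reaches the base case. NaiveLeafCount, leaves, and countLeaves are hypothetical names for this illustration; the input [2, 2, 2, 2, 5, 5] with k = 3 is our own choice (target 6, impossible because each 5 would need a 1).

```java
class NaiveLeafCount {
    static long leaves = 0; // complete assignments reached at the base case

    // Same logic as the simple recursive(...) above, plus a counter.
    static boolean recursive(int[] nums, int k, int index, int[] buckets) {
        if (index == nums.length) {
            leaves++;
            for (int i = 1; i < k; i++) {
                if (buckets[i] != buckets[i - 1]) return false;
            }
            return true;
        }
        for (int i = 0; i < k; i++) {
            buckets[i] += nums[index];                       // place
            if (recursive(nums, k, index + 1, buckets)) return true;
            buckets[i] -= nums[index];                       // backtrack
        }
        return false;
    }

    static long countLeaves(int[] nums, int k) {
        leaves = 0;
        recursive(nums, k, 0, new int[k]);
        return leaves;
    }

    public static void main(String[] args) {
        // total 18, k = 3, target 6; impossible because each 5 would need a 1
        System.out.println(countLeaves(new int[]{2, 2, 2, 2, 5, 5}, 3)); // 729 = 3^6
    }
}
```

Six numbers already cost 729 full assignments; at n = 16 and k = 4 the same exhaustive walk is the 4,294,967,296 figure above.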

We need optimizations! →

📊 Complexity Analysis

Time: O(k^n) - Exponential!
Space: O(n) - Recursion stack


🟡 Approach 2: Optimized Backtracking (Add Optimizations One by One)

🎯 Optimization 1: Check Target Instead of Final Comparison

PROBLEM with Approach 1:
  We check if buckets are equal ONLY at the end
  But we can detect violations EARLIER!

OPTIMIZATION:
  Instead of checking at the end,
  ensure buckets NEVER exceed target!

target = total / k

Before placing:
  if buckets[i] + nums[index] > target:
    skip this bucket!  // Will exceed target

This prunes MANY invalid paths early! ✓

🎯 Optimization 2: Sort in Descending Order

WHY sort descending?
  Place LARGE numbers first!

  Large numbers are more CONSTRAINED
  → Fewer valid placements
  → Fails faster if impossible
  → Prunes search space earlier

Example:
  nums = [2, 2, 2, 2, 5, 5], k = 3, target = 6
  (Impossible: each 5 needs a 1 to reach 6, and there is none.)

  If we place the 2's first:
    Many ways to place them
    Explore lots of paths
    Only at the end discover the 5's don't fit

  If we place the 5's first:
    Only a few ways to place them
    Failure shows up within a few steps!

Sort descending = explore hard choices first! ✓
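Java's Arrays.sort(int[]) only sorts ascending, which is why the implementation below pairs it with a hand-written reverse. As an alternative, a descending copy can be produced by boxing to Integer and using Comparator.reverseOrder(). SortDescending and descending are hypothetical names; boxing costs some speed, which is negligible at n <= 16.

```java
import java.util.Arrays;
import java.util.Comparator;

class SortDescending {
    // Box to Integer so Comparator.reverseOrder() can be used, then unbox back to int[].
    // Returns a new array; the input is left untouched.
    static int[] descending(int[] nums) {
        return Arrays.stream(nums)
                .boxed()
                .sorted(Comparator.reverseOrder())
                .mapToInt(Integer::intValue)
                .toArray();
    }

    public static void main(String[] args) {
        int[] sorted = descending(new int[]{4, 3, 2, 3, 5, 2, 1});
        System.out.println(Arrays.toString(sorted)); // [5, 4, 3, 3, 2, 2, 1]
    }
}
```

The in-place sort-then-reverse avoids allocating boxed Integers; either choice gives the same placement order.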

🎯 Optimization 3: Skip Equivalent Empty Buckets

INSIGHT:
  If current number doesn't work in an EMPTY bucket,
  it won't work in ANY other empty bucket!

  Why? All empty buckets are EQUIVALENT!

Example:
  buckets = [3, 5, 0, 0]
  num = 4

  Try bucket 2 (empty):
    Fails! (suppose it leads to dead end)

  Should we try bucket 3 (also empty)?
    NO! It's the same situation!
    Skip all remaining empty buckets!

if (buckets[i] == 0) break;

This dramatically reduces branches! ✓

📝 Implementation (With All Optimizations)

import java.util.Arrays;

class Solution {
    public boolean canPartitionKSubsets(int[] nums, int k) {
        int total = 0;
        for (int num : nums) {
            total += num;
        }

        if (total % k != 0) return false;

        int target = total / k;

        // Optimization 2: Sort descending
        Arrays.sort(nums);
        reverse(nums);

        // Early check: if largest > target, impossible
        if (nums[0] > target) return false;

        int[] buckets = new int[k];
        return backtrack(0, buckets, nums, target);
    }

    private boolean backtrack(int index, int[] buckets, int[] nums, int target) {
        // Base case: all numbers placed
        if (index == nums.length) {
            // Optimization 1: no final comparison needed.
            // No bucket ever exceeds target, and together the k buckets
            // hold total = k * target, so every bucket equals target.
            return true;
        }

        int num = nums[index];

        for (int i = 0; i < buckets.length; i++) {
            // Optimization 1: Check target constraint
            if (buckets[i] + num > target) {
                continue;  // Skip this bucket
            }

            buckets[i] += num;  // Place

            if (backtrack(index + 1, buckets, nums, target)) {
                return true;
            }

            buckets[i] -= num;  // Backtrack

            // Optimization 3: Skip equivalent empty buckets
            if (buckets[i] == 0) {
                break;  // Don't try other empty buckets
            }
        }

        return false;
    }

    private void reverse(int[] arr) {
        int i = 0, j = arr.length - 1;
        while (i < j) {
            int temp = arr[i];
            arr[i] = arr[j];
            arr[j] = temp;
            i++;
            j--;
        }
    }
}

Understanding the Optimizations in Action

nums = [5, 4, 3, 2], k = 2, target = 7

WITHOUT optimizations:
  Try all 2^4 = 16 combinations
  Check at end if equal
  Many wasted paths

WITH Optimization 1 (target check):
  Prune paths where bucket > 7
  Fewer paths to explore

WITH Optimization 2 (sort descending):
  Place 5 first (most constrained)
  [5,_] [_,_] or [_,_] [5,_]
  Quickly see 5+4 = 9 > 7
  Prune early!

WITH Optimization 3 (skip empty):
  If 5 fails in first empty bucket,
  don't try second empty bucket!
  Save half the work!

Combined: MUCH faster! ✓
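A rough way to measure the combined effect is to count backtrack calls. The sketch below copies the optimized backtracking (all three optimizations) and adds a call counter; PrunedCallCount, solve, and calls are hypothetical names. On the impossible input [2, 2, 2, 2, 5, 5] with k = 3, the unpruned search has to visit all 3^6 = 729 complete assignments, while this version gives up after 6 calls.

```java
import java.util.Arrays;

class PrunedCallCount {
    static int calls = 0; // number of backtrack(...) invocations in the last solve(...)

    static boolean solve(int[] input, int k) {
        int total = Arrays.stream(input).sum();
        calls = 0;
        if (total % k != 0) return false;
        int target = total / k;
        int[] nums = input.clone();
        Arrays.sort(nums);
        for (int i = 0, j = nums.length - 1; i < j; i++, j--) { // Optimization 2: descending
            int tmp = nums[i]; nums[i] = nums[j]; nums[j] = tmp;
        }
        if (nums[0] > target) return false;
        return backtrack(0, new int[k], nums, target);
    }

    static boolean backtrack(int index, int[] buckets, int[] nums, int target) {
        calls++;
        if (index == nums.length) return true;
        int num = nums[index];
        for (int i = 0; i < buckets.length; i++) {
            if (buckets[i] + num > target) continue;     // Optimization 1: never overfill
            buckets[i] += num;
            if (backtrack(index + 1, buckets, nums, target)) return true;
            buckets[i] -= num;
            if (buckets[i] == 0) break;                  // Optimization 3: skip other empty buckets
        }
        return false;
    }

    public static void main(String[] args) {
        boolean ok = solve(new int[]{2, 2, 2, 2, 5, 5}, 3);
        System.out.println(ok + " after " + calls + " calls"); // false after 6 calls
        System.out.println(solve(new int[]{5, 4, 3, 2}, 2));   // true: {5,2} and {4,3}
    }
}
```

The two 5's each claim their own bucket, the 2's pile into the third bucket, the last 2 fits nowhere, and Optimization 3 stops every level from retrying an equivalent empty bucket.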

Still Exponential, But Much Better

Time Complexity: Still O(k^n) worst case
  But pruning reduces practical runtime significantly!

Space Complexity: O(n) recursion stack

In practice: the pruning handles many inputs quickly
But worst-case inputs can still blow up

Can we do better? → MEMOIZATION!

🟢 Approach 3: Memoization with Bitmask - Building Intuition from Scratch

🤔 The Memoization Challenge

To memoize, we need to cache:
  memo[state] = result

But what is our STATE?
  (index, buckets[])

PROBLEM:
  buckets[] is an ARRAY!

  buckets = [3, 5, 2, 0]
  buckets = [5, 3, 0, 2]

  These are DIFFERENT arrays
  But represent SAME state! (just reordered)

  Can't use array as memo key! ✗

We need a BETTER state representation! 🔑
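The memo-key problem is easy to reproduce in Java: an int[] uses identity equals/hashCode, so two arrays with the same contents are different HashMap keys. One workaround is to sort a copy of the buckets and build a string key, sketched below with the hypothetical names MemoKeyDemo, canonicalKey, and rawArrayKeyMatches. It works, but every call pays for a sort plus string construction, and the number of distinct bucket states can still be large, which is what motivates looking for a more compact state.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

class MemoKeyDemo {
    // Hypothetical canonical key: sort a copy of the buckets so reordered states collide.
    static String canonicalKey(int index, int[] buckets) {
        int[] copy = buckets.clone();
        Arrays.sort(copy);
        return index + ":" + Arrays.toString(copy);
    }

    // Shows that a raw int[] key does not match an equal-content array.
    static boolean rawArrayKeyMatches(int[] a, int[] b) {
        Map<int[], Boolean> memo = new HashMap<>();
        memo.put(a, false);
        return memo.containsKey(b);        // int[] uses identity equals/hashCode
    }

    public static void main(String[] args) {
        int[] a = {3, 5, 2, 0};
        int[] b = {5, 3, 0, 2};
        System.out.println(rawArrayKeyMatches(a, b));                      // false
        System.out.println(canonicalKey(4, a).equals(canonicalKey(4, b))); // true
        System.out.println(canonicalKey(4, a));                            // 4:[0, 2, 3, 5]
    }
}
```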

💡 The KEY Insight - Change Perspective!

CURRENT APPROACH (Number-by-Number):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
  "For each number, pick a bucket"
  State: (index, buckets[])
  Problem: buckets[] hard to memoize

NEW APPROACH (Bucket-by-Bucket):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
  "Fill buckets ONE BY ONE"
  Track which numbers are USED (not which bucket they're in)

  State: (which numbers used, current bucket's sum)

  Use BITMASK for "which numbers used"!
  → Single integer!
  → Easy to memoize! ✓

This is the TRANSFORMATION we need! 🎯

🧒 Understanding Bitmask - Like You're 5 Years Old

What is a Bitmask? (Start from Scratch)

Imagine you have 4 toys: Car, Ball, Doll, Train
Toys:   Car  Ball  Doll  Train
Index:   0    1     2     3

Question: "Which toys did you play with today?"

NORMAL way to track (array of booleans):
  played = [true, false, true, false]
  → Played with Car and Doll

BITMASK way (single number in binary):
  mask = 0101 (in binary)
         ||||
         |||└─ Toy 0 (Car): 1 = YES, played with it
         ||└── Toy 1 (Ball): 0 = NO, didn't play
         |└─── Toy 2 (Doll): 1 = YES, played with it
         └──── Toy 3 (Train): 0 = NO, didn't play

Same information, but stored in ONE number! ✨

In decimal: 0101 (binary) = 5 (decimal)
So mask = 5 means "played with Car and Doll"

How Does This Help Us?

For our problem:
  nums = [4, 3, 2, 1]

  Instead of tracking "which bucket is each number in?"
  We track "which numbers have been used?"

  mask = 0000 (binary) = 0 → none used yet
  mask = 0001 (binary) = 1 → nums[0]=4 used
  mask = 0011 (binary) = 3 → nums[0]=4, nums[1]=3 used
  mask = 1011 (binary) = 11 → nums[0]=4, nums[1]=3, nums[3]=1 used
  mask = 1111 (binary) = 15 → all four used!

Each mask is just ONE NUMBER!
Perfect for memoization! ✓

Building Intuition: How to Read a Bitmask

Let's practice reading bitmasks:

nums = [4, 3, 2, 1]
index:  0  1  2  3

mask = 5 in decimal = 0101 in binary
Read right to left:
  Position 0: 1 → nums[0]=4 is USED ✓
  Position 1: 0 → nums[1]=3 is NOT used
  Position 2: 1 → nums[2]=2 is USED ✓
  Position 3: 0 → nums[3]=1 is NOT used

So mask=5 means: Used numbers {4, 2}

Another example:
mask = 12 in decimal = 1100 in binary
  Position 0: 0 → nums[0]=4 NOT used
  Position 1: 0 → nums[1]=3 NOT used
  Position 2: 1 → nums[2]=2 USED ✓
  Position 3: 1 → nums[3]=1 USED ✓

So mask=12 means: Used numbers {2, 1}

Simple pattern:
  Each bit position corresponds to a number!
  1 = used, 0 = not used
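The "read right to left" procedure above can be written out literally by walking the binary string from its last character. MaskReader and usedNumbers are hypothetical names; this string-based version only mirrors the manual reading, and real solutions test bits directly with bitwise operators instead.

```java
import java.util.ArrayList;
import java.util.List;

class MaskReader {
    // Reads the binary digits of mask right to left: the last digit is position 0.
    static List<Integer> usedNumbers(int mask, int[] nums) {
        String bits = Integer.toBinaryString(mask);          // e.g. 12 -> "1100"
        List<Integer> used = new ArrayList<>();
        for (int pos = 0; pos < nums.length && pos < bits.length(); pos++) {
            char bit = bits.charAt(bits.length() - 1 - pos);
            if (bit == '1') used.add(nums[pos]);             // 1 = used, 0 = not used
        }
        return used;
    }

    public static void main(String[] args) {
        int[] nums = {4, 3, 2, 1};
        System.out.println(usedNumbers(5, nums));   // [4, 2]
        System.out.println(usedNumbers(12, nums));  // [2, 1]
        System.out.println(usedNumbers(11, nums));  // [4, 3, 1]
    }
}
```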

The TWO Key Bitmask Operations

OPERATION 1: Check if number i is used
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Example: Is nums[1]=3 used in mask=5?

mask = 5 = 0101
Want to check position 1 (second from right)

Step 1: Create a "probe" with only position 1 set
  (1 << 1) = 0010  (only position 1 is 1)

Step 2: AND with mask
  0101  (mask=5)
& 0010  (1 << 1)
──────
  0000  (result = 0)

Result = 0 means position 1 is NOT set
So nums[1] is NOT used! ✓

Check nums[0]:
  (1 << 0) = 0001
  0101 & 0001 = 0001 (non-zero!)
  So nums[0] IS used! ✓

Formula: (mask & (1 << i)) != 0
         If NOT zero → number i is USED
         If zero → number i is NOT used

OPERATION 2: Mark number i as used
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Example: Mark nums[1]=3 as used in mask=5

mask = 5 = 0101
Want to turn ON position 1

Step 1: Create mask with only position 1 set
  (1 << 1) = 0010

Step 2: OR with current mask
  0101  (mask=5)
| 0010  (1 << 1)
──────
  0111  (result = 7)

New mask = 7 = 0111
Now position 1 is ON!
Used numbers: {4, 3, 2}

Formula: mask | (1 << i)
         Turns ON bit i (marks as used)

Visual summary:
  Before: 0101 (nums 4,2 used)
  After:  0111 (nums 4,3,2 used)
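The two operations translate directly into Java helpers. MaskOps, isUsed, and markUsed are hypothetical names for this sketch; note that int masks are fine here because n <= 16.

```java
class MaskOps {
    // Operation 1: is bit i set, i.e. is nums[i] used?
    static boolean isUsed(int mask, int i) {
        return (mask & (1 << i)) != 0;
    }

    // Operation 2: same mask with bit i turned on (nums[i] marked as used).
    static int markUsed(int mask, int i) {
        return mask | (1 << i);
    }

    public static void main(String[] args) {
        int mask = 5;                                   // 0101: nums[0] and nums[2] used
        System.out.println(isUsed(mask, 1));            // false
        System.out.println(isUsed(mask, 0));            // true
        int next = markUsed(mask, 1);
        System.out.println(next + " = " + Integer.toBinaryString(next)); // 7 = 111
        System.out.println(markUsed(next, 1) == next);  // true: marking twice changes nothing
    }
}
```

Because OR never clears a bit, markUsed returns a new value and leaves the caller's mask untouched, which is why the recursion can simply pass mask | (1 << i) down without any explicit "undo" step.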

Putting It Together: Our State

OLD STATE (hard to memoize):
  (index, buckets[])
  buckets[] = which bucket has what
  Example: [3, 5, 2] means bucket0=3, bucket1=5, bucket2=2

NEW STATE (easy to memoize):
  (mask, currentSum)

  mask = which numbers are used (in completed buckets OR the current one)
  currentSum = sum of the bucket we're currently filling

Example:
  nums = [4, 3, 2, 1], target = 5

  State: (mask=0001, currentSum=4)
  Meaning:
    - nums[0]=4 is used (it sits in the current bucket)
    - We're filling the current bucket
    - Current bucket has sum = 4 so far
    - Need 1 more to reach target 5

  State: (mask=1001, currentSum=0)
  Meaning:
    - nums[0]=4 and nums[3]=1 are used
    - They completed a bucket! [4,1]=5 ✓
    - Now currentSum=0 means starting a NEW bucket
    - Need to pick from unused: {3, 2}

The mask + currentSum perfectly captures our state! ✓
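In fact currentSum does not carry independent information: every completed bucket holds exactly target, so the current bucket's fill level is the sum of the used numbers modulo target. The sketch below checks this for the masks in the example; ImpliedSum and impliedCurrentSum are hypothetical names. It only applies to masks the search can reach (no bucket ever overfills), e.g. mask 3 = {4, 3} is never reached since 4 + 3 > 5.

```java
class ImpliedSum {
    // Sum of the numbers whose bits are set in mask, reduced modulo target.
    // Completed buckets each hold exactly target, so what is left over is the fill level
    // of the bucket being built (0 right after a bucket closes).
    static int impliedCurrentSum(int mask, int[] nums, int target) {
        int sum = 0;
        for (int i = 0; i < nums.length; i++) {
            if ((mask & (1 << i)) != 0) sum += nums[i];
        }
        return sum % target;
    }

    public static void main(String[] args) {
        int[] nums = {4, 3, 2, 1};
        int target = 5;
        for (int mask : new int[]{1, 9, 11, 15}) {
            System.out.println("mask=" + mask + " -> currentSum=" + impliedCurrentSum(mask, nums, target));
        }
        // mask=1 -> 4, mask=9 -> 0, mask=11 -> 3, mask=15 -> 0
    }
}
```

This is what makes a memo array indexed by mask alone safe: two paths that reach the same mask are necessarily at the same point in the current bucket.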

📐 Function Definition (Bitmask Approach)

Function Signature:

boolean canPartition(int mask, int currentSum, int target, int[] nums, Boolean[] memo)

What does this function compute?
- Input mask: Bitmask showing which numbers are already used
- Input currentSum: Sum of the bucket currently being filled
- Input target: Target sum each bucket should reach
- Input nums: The array of numbers
- Input memo: Memoization array indexed by mask
- Output: true if the unused numbers can finish the current bucket and fill all remaining buckets to target

The Recursive Structure:

Base case:
  if mask == (1 << n) - 1:  // All bits set = all numbers used
    return true  // All buckets complete!

If currentSum == target:  // Current bucket full!
  Bucket complete! Start new bucket
  return canPartition(mask, 0, target, nums, memo)

Check memo:
  if memo[mask] != null:
    return memo[mask]

Try each unused number:
  for i = 0 to n-1:
    if (mask & (1 << i)) == 0:  // number i NOT used
      if currentSum + nums[i] <= target:
        newMask = mask | (1 << i)  // mark as used
        if canPartition(newMask, currentSum + nums[i], ...):
          memo[mask] = true
          return true

memo[mask] = false
return false

Why this works:

We fill buckets ONE BY ONE:

  Bucket 1:
    Start with currentSum = 0
    Add numbers until currentSum = target
    Mark those numbers as used in mask

  Bucket 2:
    Reset currentSum = 0
    Add unused numbers until currentSum = target
    Mark those as used too

  Continue until all numbers used (mask = 1111...1)

The bitmask tracks which numbers have been placed (in completed buckets or the current one)!
We don't care WHICH bucket, just whether they're used! ✓

💡 Intuition with Visual Example

Fill buckets sequentially:

target = 5
nums = [4, 3, 2, 1]
index:  0  1  2  3

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
FILLING BUCKET 1
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

State: (mask=0000, currentSum=0)
  Used: none, Current bucket: empty
  Try nums[0]=4:

State: (mask=0001, currentSum=4)
  Used: {4}, Current bucket: [4]
  Need 1 more to reach 5
  Try nums[3]=1:

State: (mask=1001, currentSum=5)
  currentSum == target!
  Bucket 1 complete: [4,1] = 5 ✓

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
FILLING BUCKET 2
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

State: (mask=1001, currentSum=0)
  Used: {4,1} (in bucket 1), Current bucket: empty
  Available: {3,2}
  Try nums[1]=3:

State: (mask=1011, currentSum=3)
  Used: {4,3,1}, Current bucket: [3]
  Need 2 more
  Try nums[2]=2:

State: (mask=1111, currentSum=5)
  currentSum == target!
  Bucket 2 complete: [3,2] = 5 ✓

State: (mask=1111, ...)
  mask == 1111 (all used!)
  Base case: return true! ✓

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Final partition:
  Bucket 1: [4,1] = 5
  Bucket 2: [3,2] = 5

The bitmask elegantly tracked which numbers were used!

📝 Implementation

import java.util.Arrays;

class Solution {
    public boolean canPartitionKSubsets(int[] nums, int k) {
        int total = 0;
        for (int num : nums) {
            total += num;
        }

        if (total % k != 0) return false;

        int target = total / k;
        int n = nums.length;

        for (int num : nums) {
            if (num > target) return false;
        }

        // Optimization: sort descending
        Arrays.sort(nums);
        reverse(nums);

        // memo[mask] = can the numbers NOT in mask complete the partition?
        // (currentSum is implied by mask, so mask alone is a safe key)
        Boolean[] memo = new Boolean[1 << n];

        return backtrackBitmask(0, 0, target, nums, memo);
    }

    private boolean backtrackBitmask(int mask, int currentSum, int target, 
                                     int[] nums, Boolean[] memo) {
        int n = nums.length;

        // Base case: all numbers used
        if (mask == (1 << n) - 1) {
            return true;
        }

        // Current bucket complete! Start new bucket
        if (currentSum == target) {
            // Check memo before recursing
            if (memo[mask] != null) {
                return memo[mask];
            }

            boolean result = backtrackBitmask(mask, 0, target, nums, memo);
            memo[mask] = result;
            return result;
        }

        // Check memo for current state
        if (memo[mask] != null) {
            return memo[mask];
        }

        // Try each unused number
        for (int i = 0; i < n; i++) {
            // Is number i already used?
            if ((mask & (1 << i)) != 0) continue;

            // Would this exceed target?
            if (currentSum + nums[i] > target) continue;

            // Try using this number
            int newMask = mask | (1 << i);
            if (backtrackBitmask(newMask, currentSum + nums[i], target, nums, memo)) {
                memo[mask] = true;
                return true;
            }
        }

        memo[mask] = false;
        return false;
    }

    private void reverse(int[] arr) {
        int i = 0, j = arr.length - 1;
        while (i < j) {
            int temp = arr[i];
            arr[i] = arr[j];
            arr[j] = temp;
            i++;
            j--;
        }
    }
}

🔍 Detailed Dry Run: nums = [4,3,2,1], k = 2, target = 5

n = 4
Initial: mask = 0000 (decimal 0), currentSum = 0

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
BITMASK BACKTRACKING
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

backtrackBitmask(mask=0, currentSum=0):
  All available: [4,3,2,1]

  Try i=0 (num=4):
    Check if used: (0 & 1) = 0 ✓ not used
    0 + 4 = 4 <= 5 ✓
    newMask = 0 | 1 = 1 (binary: 0001)

    backtrackBitmask(mask=1, currentSum=4):
      Used: {4}, Available: {3,2,1}

      Try i=1 (num=3):
        Check: (1 & 2) = 0 ✓ not used
        4 + 3 = 7 > 5 ✗

      Try i=2 (num=2):
        Check: (1 & 4) = 0 ✓ not used
        4 + 2 = 6 > 5 ✗

      Try i=3 (num=1):
        Check: (1 & 8) = 0 ✓ not used
        4 + 1 = 5 <= 5 ✓
        newMask = 1 | 8 = 9 (binary: 1001)

        backtrackBitmask(mask=9, currentSum=5):
          currentSum == target!
          Bucket 1 complete: [4,1] = 5 ✓
          Check memo[9]? null

          Start new bucket:
          backtrackBitmask(mask=9, currentSum=0):
            Used: {4,1}, Available: {3,2}

            Try i=1 (num=3):
              Check: (9 & 2) = 0 ✓ not used
              0 + 3 = 3 <= 5 ✓
              newMask = 9 | 2 = 11 (binary: 1011)

              backtrackBitmask(mask=11, currentSum=3):
                Used: {4,3,1}, Available: {2}

                Try i=2 (num=2):
                  Check: (11 & 4) = 0 ✓ not used
                  3 + 2 = 5 <= 5 ✓
                  newMask = 11 | 4 = 15 (binary: 1111)

                  backtrackBitmask(mask=15, currentSum=5):
                    mask == 15 == (1<<4)-1
                    Base case is checked FIRST: all used!
                    return true ✓ (memo[15] is never written)

                memo[11] = true
                return true ✓

            memo[9] = true
            return true ✓

          memo[9] = result = true (stored again by this call)
          return true ✓

      memo[1] = true
      return true ✓

  memo[0] = true
  return true ✓

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
RESULT: true ✓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Partition:
  Bucket 1: {4,1} = 5 (mask transition: 0→1→9)
  Bucket 2: {3,2} = 5 (mask transition: 9→11→15)

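Memo entries saved:
  memo[11] = true (after adding 3, can add 2)
  memo[9] = true (after bucket 1, can complete the rest)
  memo[1] = true (with 4 in the current bucket, can finish)
  memo[0] = true (the whole problem)
  memo[15] is never stored: the all-used base case returns before any memo write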

Why Memoization Works Here

KEY INSIGHT:
  Once we've computed memo[mask],
  the answer is ALWAYS the same!

  Doesn't matter:
    - Which bucket we're filling (1st, 2nd, 3rd, ...)
    - What order we used numbers
    - How we reached this mask

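  Only matters:
    - Which numbers are used (mask)
    - How full the current bucket is, which mask already implies:
      currentSum = (sum of numbers in mask) mod target
      (every completed bucket holds exactly target; a just-filled
       bucket, currentSum == target, acts exactly like currentSum == 0)

  So memo[mask] is valid at EVERY state, not just at bucket
  boundaries, and the code stores a result for each mask it finishes.

  This captures the state perfectly! ✓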

Example:
  If memo[9] = true, it means:
    "Given that {4,1} are used,
     we CAN complete the remaining buckets"

  Next time we see mask=9, instant answer! ✓

📊 Complexity Analysis

Time Complexity: O(2^n × n)

States: 2^n possible masks
For each state: try n numbers
Total: O(2^n × n)

With memoization:
  Each state computed ONCE

For n=16:
  2^16 × 16 = 1,048,576 operations
  Feasible! ✓

Space Complexity: O(2^n)

Memo array: 2^n entries
Recursion stack: O(n)
Total: O(2^n)


🔵 Approach 4: Bottom-Up DP - Deriving from Top-Down

🎯 Can We Do Bottom-Up?

Our top-down uses:
  memo[mask] = can we partition numbers NOT in mask?

Can we build this bottom-up?

Challenge:
  The recursion fills buckets ONE BY ONE
  Order matters: currentSum resets when bucket fills

  This makes bottom-up TRICKY!

But we can still do it! Let's derive step by step...

📐 Understanding the State Transition

TOP-DOWN (what we have):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

memo[mask] = can we complete remaining buckets?
             starting with currentSum = 0

When currentSum reaches target:
  → Bucket done
  → Reset to 0
  → Continue with same mask

BOTTOM-UP (what we'll build):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

dp[mask] = can the numbers in mask be placed, in some order, into
           full buckets plus at most one partial bucket, without
           any bucket ever exceeding target?

We also need:
  remainder[mask] = sum of the partial bucket for this mask
                  = (sum of numbers in mask) mod target

Build from smaller masks to larger masks!

🧒 Deriving Base Cases (ELI5)

BASE CASE in TOP-DOWN:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

if (mask == (1 << n) - 1):  // All numbers used
  return true

This means: "If all numbers are used, we're done!"

BOTTOM-UP TRANSLATION:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Start from mask = 0 (no numbers used)
Build up to mask = (1 << n) - 1 (all used)

dp[0] = true (no numbers used yet, valid starting point)
remainder[0] = 0 (no current sum)

Goal: compute dp[(1<<n)-1]

If dp[(1<<n)-1] = true → answer is true!

📝 Bottom-Up Implementation

class Solution {
    public boolean canPartitionKSubsets(int[] nums, int k) {
        int total = 0;
        for (int num : nums) {
            total += num;
        }

        if (total % k != 0) return false;

        int target = total / k;
        int n = nums.length;

        for (int num : nums) {
            if (num > target) return false;
        }

        // dp[mask] = can we partition numbers in mask validly?
        // remainder[mask] = current bucket's sum for this mask
        boolean[] dp = new boolean[1 << n];
        int[] remainder = new int[1 << n];

        // Base case
        dp[0] = true;
        remainder[0] = 0;

        // Build from smaller masks to larger
        for (int mask = 0; mask < (1 << n); mask++) {
            if (!dp[mask]) continue;  // Skip invalid states

            // Try adding each unused number
            for (int i = 0; i < n; i++) {
                // Is number i already used in mask?
                if ((mask & (1 << i)) != 0) continue;

                int newMask = mask | (1 << i);
                int newRemainder = (remainder[mask] + nums[i]) % target;

                // Can we add this number?
                if (remainder[mask] + nums[i] <= target) {
                    dp[newMask] = true;
                    remainder[newMask] = newRemainder;
                }
            }
        }

        return dp[(1 << n) - 1];
    }
}

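🔍 Detailed Dry Run: nums = [4,2], k = 2, target = 3

n = 2
target = 3
(In the code above, the num > target guard returns false right away since 4 > 3.
 This trace skips that guard to show that the DP loop alone reaches the same answer.)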

Initial:
  dp[0] = true
  remainder[0] = 0

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
BUILD DP TABLE
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Process mask = 0 (binary: 00):
  dp[0] = true, remainder[0] = 0
  Available: {4, 2}

  Try i=0 (num=4):
    Not used: ✓
    remainder[0] + 4 = 0 + 4 = 4
    4 > target (3)? YES ✗
    Can't add!

  Try i=1 (num=2):
    Not used: ✓
    remainder[0] + 2 = 0 + 2 = 2
    2 <= target? YES ✓
    newMask = 0 | 2 = 2 (binary: 10)
    newRemainder = (0 + 2) % 3 = 2
    dp[2] = true
    remainder[2] = 2

Process mask = 1 (binary: 01):
  dp[1] = false → skip

Process mask = 2 (binary: 10):
  dp[2] = true, remainder[2] = 2
  Used: {2}, Available: {4}

  Try i=0 (num=4):
    Not used: ✓
    remainder[2] + 4 = 2 + 4 = 6
    6 > target? YES ✗
    Can't add!

  Try i=1 (num=2):
    Already used ✗

Process mask = 3 (binary: 11):
  dp[3] = false

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
RESULT: dp[3] = false ✗
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Can't partition [4,2] into 2 groups of sum 3
(4 > 3, impossible!)
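The [4,2] trace never completes a bucket, so the "% target" wrap-around of remainder never shows up. The sketch below reruns the same bottom-up transitions on a small input of our own choosing, nums = [2, 1, 1] with k = 2 (target 2), and optionally prints the whole table. BottomUpTrace, canPartition, and the verbose flag are hypothetical names added for this illustration.

```java
import java.util.Arrays;

class BottomUpTrace {
    // Same transitions as the bottom-up solution above, with an optional table printout.
    static boolean canPartition(int[] nums, int k, boolean verbose) {
        int total = Arrays.stream(nums).sum();
        if (total % k != 0) return false;
        int target = total / k;
        for (int num : nums) {
            if (num > target) return false;
        }
        int n = nums.length;
        boolean[] dp = new boolean[1 << n];
        int[] remainder = new int[1 << n];
        dp[0] = true;
        for (int mask = 0; mask < (1 << n); mask++) {
            if (!dp[mask]) continue;
            for (int i = 0; i < n; i++) {
                if ((mask & (1 << i)) != 0) continue;               // already used
                if (remainder[mask] + nums[i] > target) continue;   // would overfill the bucket
                int newMask = mask | (1 << i);
                dp[newMask] = true;
                remainder[newMask] = (remainder[mask] + nums[i]) % target; // 0 when a bucket closes
            }
        }
        if (verbose) {
            for (int mask = 0; mask < (1 << n); mask++) {
                String bits = String.format("%" + n + "s", Integer.toBinaryString(mask)).replace(' ', '0');
                System.out.println(bits + "  dp=" + dp[mask] + "  remainder=" + remainder[mask]);
            }
        }
        return dp[(1 << n) - 1];
    }

    public static void main(String[] args) {
        // nums = [2,1,1], k = 2, target = 2: buckets {2} and {1,1}
        System.out.println(canPartition(new int[]{2, 1, 1}, 2, true));
    }
}
```

In the printed table, remainder drops back to 0 at mask 001 ({2} fills a bucket) and at mask 110 ({1,1} fills a bucket), and dp[111] ends up true.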

Why Bottom-Up is Harder Here

Bottom-up DP works BUT:
  - Need to track remainder carefully
  - Masks must be processed in increasing order (each transition only adds a bit, so a plain loop works)
  - Less intuitive than top-down

Top-down is MUCH cleaner for this problem!
  - Natural recursion matches bucket-filling
  - currentSum tracked naturally
  - Memoization straightforward

RECOMMENDATION:
  Use top-down (Approach 3) for this problem! ✓
  Bottom-up possible but unnecessarily complex

📊 Complexity

Time: O(2^n × n)
Space: O(2^n)
Clarity: Lower than top-down


📊 Complete Approach Comparison

┌────────────────────────────────────────────────────────────────┐
│     PARTITION K EQUAL SUBSETS - APPROACH COMPARISON            │
├──────────────────┬──────────┬──────────┬──────────┬────────────┤
│ Approach         │ Time     │ Space    │ Clarity  │ Interview  │
├──────────────────┼──────────┼──────────┼──────────┼────────────┤
│ Simple Backtrack │ O(k^n)   │ O(n)     │ High     │ Start      │
│ Optimized BT     │ O(k^n)*  │ O(n)     │ High     │ Good       │
│ Bitmask Memo     │ O(2^n×n) │ O(2^n)   │ Medium   │ Best ✓     │
│ Bottom-Up DP     │ O(2^n×n) │ O(2^n)   │ Low      │ Optional   │
└──────────────────┴──────────┴──────────┴──────────┴────────────┘

* Much faster with pruning in practice
Constraint n <= 16: at most 2^16 = 65,536 masks, so bitmask DP is a perfect fit!
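To make the pruning footnote concrete, this hypothetical `PruningDemo` counts recursive calls of the Approach 2 bucket-filling search with and without the "empty bucket → break" rule. The input [3,3,2,2,2] with k = 3 is an illustrative infeasible case (target 4, but a 3 would need a 1), chosen so both searches must run to exhaustion. The sketch assumes total % k == 0 and does not repeat the other early exits.

```java
import java.util.Arrays;

class PruningDemo {
    // Counts invocations of fill(...) during the most recent search (illustrative only).
    static long calls;

    // Bucket-filling backtracking in the style of Approach 2.
    // pruneEmpty toggles the "bucket is empty after undo -> break" rule.
    static boolean fill(int index, int[] buckets, int[] nums, int target, boolean pruneEmpty) {
        calls++;
        if (index == nums.length) {
            return true; // every number placed without exceeding target
        }
        for (int i = 0; i < buckets.length; i++) {
            if (buckets[i] + nums[index] > target) {
                continue; // would overshoot this bucket
            }
            buckets[i] += nums[index];
            if (fill(index + 1, buckets, nums, target, pruneEmpty)) {
                return true;
            }
            buckets[i] -= nums[index];
            if (pruneEmpty && buckets[i] == 0) {
                break; // other empty buckets are interchangeable with this one
            }
        }
        return false;
    }

    // Runs one search on a descending-sorted copy and returns how many calls it made.
    static long countCalls(int[] nums, int k, boolean pruneEmpty) {
        int[] sorted = nums.clone();
        Arrays.sort(sorted);
        // reverse into descending order, matching Approach 2
        for (int i = 0, j = sorted.length - 1; i < j; i++, j--) {
            int tmp = sorted[i];
            sorted[i] = sorted[j];
            sorted[j] = tmp;
        }
        int target = Arrays.stream(sorted).sum() / k;
        calls = 0;
        fill(0, new int[k], sorted, target, pruneEmpty);
        return calls;
    }

    public static void main(String[] args) {
        int[] nums = {3, 3, 2, 2, 2}; // total 12, k = 3, target 4: infeasible
        System.out.println("with pruning:    " + countCalls(nums, 3, true));
        System.out.println("without pruning: " + countCalls(nums, 3, false));
    }
}
```

Because neither search can succeed on this input, the pruned search visits a strict subset of the unpruned call tree, so its count comes out smaller. Both variants still return the same answer.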

💻 Complete Working Code

import java.util.Arrays;

class Solution {
  public boolean canPartitionKSubsets(int[] nums, int k) {
    return bitmaskDP(nums, k);
  }

  // Approach 3: Bitmask DP (Top-Down) - O(2^n × n), RECOMMENDED ✓
  private boolean bitmaskDP(int[] nums, int k) {
    int total = 0;
    for (int num : nums) {
      total += num;
    }

    if (total % k != 0)
      return false;

    int target = total / k;
    int n = nums.length;

    for (int num : nums) {
      if (num > target)
        return false;
    }

    Arrays.sort(nums);
    reverse(nums);

    Boolean[] memo = new Boolean[1 << n];
    return backtrackBitmask(0, 0, target, nums, memo);
  }

  private boolean backtrackBitmask(int mask, int currentSum, int target, int[] nums, Boolean[] memo) {
    int n = nums.length;

    if (mask == (1 << n) - 1) {
      return true;
    }

    if (currentSum == target) {
      if (memo[mask] != null) {
        return memo[mask];
      }

      boolean result = backtrackBitmask(mask, 0, target, nums, memo);
      memo[mask] = result;
      return result;
    }

    if (memo[mask] != null) {
      return memo[mask];
    }

    for (int i = 0; i < n; i++) {
      if ((mask & (1 << i)) != 0)
        continue;
      if (currentSum + nums[i] > target)
        continue;

      int newMask = mask | (1 << i);
      if (backtrackBitmask(newMask, currentSum + nums[i], target, nums, memo)) {
        memo[mask] = true;
        return true;
      }
    }

    memo[mask] = false;
    return false;
  }

  // Approach 4: Bottom-Up DP - O(2^n × n), less intuitive
  private boolean bottomUpDP(int[] nums, int k) {
    int total = 0;
    for (int num : nums) {
      total += num;
    }

    if (total % k != 0)
      return false;

    int target = total / k;
    int n = nums.length;

    for (int num : nums) {
      if (num > target)
        return false;
    }

    boolean[] dp = new boolean[1 << n];
    int[] remainder = new int[1 << n];

    dp[0] = true;
    remainder[0] = 0;

    for (int mask = 0; mask < (1 << n); mask++) {
      if (!dp[mask])
        continue;

      for (int i = 0; i < n; i++) {
        if ((mask & (1 << i)) != 0)
          continue;

        int newMask = mask | (1 << i);

        if (remainder[mask] + nums[i] <= target) {
          dp[newMask] = true;
          remainder[newMask] = (remainder[mask] + nums[i]) % target;
        }
      }
    }

    return dp[(1 << n) - 1];
  }

  // Approach 2: Optimized Backtracking - O(k^n), practical
  private boolean optimizedBacktracking(int[] nums, int k) {
    int total = 0;
    for (int num : nums) {
      total += num;
    }

    if (total % k != 0)
      return false;

    int target = total / k;

    Arrays.sort(nums);
    reverse(nums);

    if (nums[0] > target)
      return false;

    int[] buckets = new int[k];
    return backtrackOptimized(0, buckets, nums, target);
  }

  private boolean backtrackOptimized(int index, int[] buckets, int[] nums, int target) {
    if (index == nums.length) {
      return true;
    }

    int num = nums[index];

    for (int i = 0; i < buckets.length; i++) {
      if (buckets[i] + num > target) {
        continue;
      }

      buckets[i] += num;

      if (backtrackOptimized(index + 1, buckets, nums, target)) {
        return true;
      }

      buckets[i] -= num;

      if (buckets[i] == 0)
        break;
    }

    return false;
  }

  // Approach 1: Simple Backtracking - O(k^n), what we discover first
  private boolean simpleBacktracking(int[] nums, int k) {
    int total = 0;
    for (int num : nums) {
      total += num;
    }

    if (total % k != 0)
      return false;

    int[] buckets = new int[k];
    return recursive(nums, k, 0, buckets);
  }

  private boolean recursive(int[] nums, int k, int index, int[] buckets) {
    if (index == nums.length) {
      for (int i = 1; i < buckets.length; i++) {
        if (buckets[i] != buckets[i - 1]) {
          return false;
        }
      }
      return true;
    }

    for (int i = 0; i < k; i++) {
      buckets[i] += nums[index];

      if (recursive(nums, k, index + 1, buckets)) {
        return true;
      }

      buckets[i] -= nums[index];
    }

    return false;
  }

  private void reverse(int[] arr) {
    int i = 0, j = arr.length - 1;
    while (i < j) {
      int temp = arr[i];
      arr[i] = arr[j];
      arr[j] = temp;
      i++;
      j--;
    }
  }

  public static void main(String[] args) {
    Solution s = new Solution();

    System.out.println(s.canPartitionKSubsets(new int[] { 4, 3, 2, 3, 5, 2, 1 }, 4) == true);
    System.out.println(s.canPartitionKSubsets(new int[] { 1, 2, 3, 4 }, 3) == false);
    System.out.println(s.canPartitionKSubsets(new int[] { 4, 3, 2, 1 }, 2) == true);
  }
}

🔑 Key Insights

The Complete Journey

1. Start Simple:
   Try all assignments → O(k^n)

2. Add Optimizations:
   Check target, sort, skip empties
   → Still O(k^n) but faster

3. Hit Wall:
   Can't memoize on the buckets[] array!
   (far too many bucket states to use as a key)

4. Transform:
   Use bitmask to track used numbers
   → Can memoize! O(2^n × n)

5. Success:
   Top-down bitmask DP optimal!
   Bottom-up possible but complex

Bitmask Power

Why bitmask is elegant:
  ✓ Compresses state into single int
  ✓ Fast bit operations
  ✓ Perfect for memoization
  ✓ Works great for n ≤ 16

This technique appears in MANY hard problems!

🎪 Memory Aid

"k=2 → subset-sum DP, k>2 → Backtracking / Bitmask DP!"
"Bitmask = which numbers used (1=yes, 0=no)!"
"Check bit: (mask & (1<<i)) != 0"
"Set bit: mask | (1<<i)"
"Top-down cleaner than bottom-up!"

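The "k=2 → DP" mnemonic refers to the two-subset case: if one subset reaches total / 2, the remaining numbers automatically make up the other half, so a plain 0/1 subset-sum table suffices. A minimal sketch in Java, with the hypothetical helper name `canPartitionInTwo`:

```java
class TwoWayPartition {
    // k = 2 special case: does some subset sum to total / 2? (1-D 0/1 subset-sum table)
    static boolean canPartitionInTwo(int[] nums) {
        int total = 0;
        for (int num : nums) total += num;
        if (total % 2 != 0) return false; // odd total can't split evenly
        int half = total / 2;
        boolean[] reachable = new boolean[half + 1]; // reachable[s]: some subset sums to s
        reachable[0] = true;
        for (int num : nums) {
            // iterate downward so each number is used at most once
            for (int s = half; s >= num; s--) {
                reachable[s] = reachable[s] || reachable[s - num];
            }
        }
        return reachable[half];
    }

    public static void main(String[] args) {
        System.out.println(canPartitionInTwo(new int[] {4, 3, 2, 1})); // true: {4,1} and {3,2}
        System.out.println(canPartitionInTwo(new int[] {1, 2, 5}));    // false
    }
}
```

For k > 2 this single table is not enough, because filling one subset does not guarantee that the leftovers split evenly, which is why the bucket or bitmask search is needed.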

📝 Quick Revision Notes

🎯 Core Concept

Problem: Partition into k equal sum subsets
Check: total % k == 0, target = total/k
Simple: Backtracking O(k^n)
Optimal: Bitmask DP O(2^n × n) ✓

⚡ Bitmask Quick Reference

Check if i used: (mask & (1 << i)) != 0
Mark i as used: mask | (1 << i)

Example: nums = [4,3,2,1] (bit i ↔ nums[i], rightmost bit = index 0)
  mask = 5 = 0101 → {4,2} used
  mask = 11 = 1011 → {4,3,1} used
  mask = 15 = 1111 → all used
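The check/set operations above can be exercised directly. This sketch decodes a mask back into the numbers it marks as used, assuming bit i corresponds to nums[i] (rightmost bit = index 0). `MaskDecoder` and `usedNumbers` are illustrative names.

```java
import java.util.ArrayList;
import java.util.List;

class MaskDecoder {
    // Lists nums[i] for every bit i set in mask (bit 0 = rightmost = nums[0]).
    static List<Integer> usedNumbers(int mask, int[] nums) {
        List<Integer> used = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            if ((mask & (1 << i)) != 0) { // check bit i
                used.add(nums[i]);
            }
        }
        return used;
    }

    public static void main(String[] args) {
        int[] nums = {4, 3, 2, 1};
        System.out.println(usedNumbers(5, nums));   // [4, 2]
        System.out.println(usedNumbers(11, nums));  // [4, 3, 1]
        System.out.println(usedNumbers(15, nums));  // [4, 3, 2, 1]
        int mask = 5 | (1 << 1);                    // set bit 1: mark nums[1] = 3 as used
        System.out.println(mask + " " + usedNumbers(mask, nums)); // 7 [4, 3, 2]
    }
}
```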

Implementation

import java.util.Arrays;

public class Solution {
  public boolean canPartitionKSubsets(int[] nums, int k) {
    int[] buckets = new int[k];

    // return recursive(nums, k, 0, buckets);

    // Optimal 1: sort the array in descending order so that we can have early exits
    Arrays.sort(nums);
    reverse(nums);

    // Optimal 2: instead of filling every bucket and only comparing them at the end,
    // exit early. Each bucket must reach target = total / k, so if total is not
    // divisible by k, or the largest number already exceeds target, no partition exists.
    int total = Arrays.stream(nums).sum();
    if (total % k != 0 || nums[0] > total / k) {
      return false;
    }

    // return recursive_optimal(nums, k, 0, buckets, total / k);

    // step 1: consider [4,3,2,1]
    // a bitmask records which numbers (by index) are already used:
    // bit i <-> nums[i], with the rightmost bit as index 0
    // "1111" => all numbers are used
    // "0101" => 4 and 2 are used (indices 0 and 2)
    // "0000" => no number used (mask = 0)
    // The memo needs one slot per mask 0..15, i.e. 1 << 4 = 16 (binary "10000") slots.
    Boolean[] memo = new Boolean[1 << nums.length];
    return topDown(nums, total / k, 0, 0, memo);
  }

  private boolean topDown(int[] nums, int target, int mask, int currentSum, Boolean[] memo) {
    // step 7: base case
    if (mask == (1 << nums.length) - 1) {
      // for example, mask = 1111 (15, all used) for [4,3,2,1]
      // 1<<4 => 10000 (16)
      // (1<<4)-1 => 16-1 = 15 (1111)
      return true;
    }

    // step 8: the current bucket is full. Start a fresh bucket (currentSum = 0)
    // with the numbers still unused, reusing the cached answer for this mask if any.
    if (currentSum == target) {
      if (memo[mask] != null) {
        return memo[mask];
      }

      memo[mask] = topDown(nums, target, mask, 0, memo);
      return memo[mask];
    }

    // step 9: this set of used numbers was already explored
    if (memo[mask] != null) {
      return memo[mask];
    }

    // step 2: try each number in the current bucket
    for (int i = 0; i < nums.length; i++) {
      // step 3: check if the number at index i is already used
      // e.g. mask = 0011, i = 0: 0011 & 0001 = 0001 != 0 => used
      if ((mask & (1 << i)) != 0) {
        continue; // used. Proceed to the next number
      }

      // step 4: skip the number if it would overshoot the bucket's target
      if (nums[i] + currentSum > target) {
        continue;
      }

      // step 5: mark index i as used
      int updatedMask = mask | (1 << i);

      // step 6: recurse with the updated mask and running sum
      if (topDown(nums, target, updatedMask, currentSum + nums[i], memo)) {
        return memo[mask] = true;
      }
    }

    return memo[mask] = false;
  }

  private void reverse(int[] nums) {
    int start = 0;
    int end = nums.length - 1;

    while (start <= end) {
      int temp = nums[start];
      nums[start] = nums[end];
      nums[end] = temp;

      start++;
      end--;
    }
  }

  private boolean recursive_optimal(int[] nums, int k, int index, int[] buckets, int target) {
    // step 3: base case
    // optimal 5: no need to compare the buckets here. No bucket can exceed target
    // and the total is k * target, so once every number is placed, all buckets equal target.
    if (index == nums.length) {
      return true;
    }

    // step 1: try to keep number at index at each bucket and proceed to next index
    boolean res = false;
    for (int i = 0; i < k; i++) {
      // step 2a: try number at index in bucket i
      // Optimal 3: whether adding this number is useful or not
      // skip it if not useful
      if (buckets[i] + nums[index] > target) {
        continue;
      }

      buckets[i] = buckets[i] + nums[index];
      res = recursive_optimal(nums, k, index + 1, buckets, target);

      if (res) {
        return true;
      }

      // step 2b: remove number at index in bucket i and try next i
      buckets[i] = buckets[i] - nums[index];

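      // optimal 4: if placing this number in an empty bucket failed, placing it in any
      // other empty bucket fails too (empty buckets are interchangeable). Buckets fill
      // left to right, so every bucket after this one is empty as well: skip them all.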
      if (buckets[i] == 0) {
        break;
      }
    }

    return false;
  }

  private boolean recursive(int[] nums, int k, int index, int[] buckets) {
    // step 3: base case
    // check if all buckets have same value, then true.
    if (index == nums.length) {
      for (int i = 1; i < buckets.length; i++) {
        if (buckets[i] != buckets[i - 1]) {
          return false;
        }
      }

      return true;
    }

    // step 1: try to keep number at index at each bucket and proceed to next index
    boolean res = false;
    for (int i = 0; i < k; i++) {
      // step 2a: try number at index in bucket i
      buckets[i] = buckets[i] + nums[index];
      res = res | recursive(nums, k, index + 1, buckets);

      if (res) {
        return true;
      }

      // step 2b: remove number at index in bucket i and try next i
      buckets[i] = buckets[i] - nums[index];
    }

    return false;
  }

  public static void main(String[] args) {
    Solution s = new Solution();
    System.out.println(s.canPartitionKSubsets(new int[] { 4, 3, 2, 3, 5, 2, 1 }, 4) == true);
    System.out.println(s.canPartitionKSubsets(new int[] { 1, 2, 3, 4 }, 3) == false);
  }
}

🎉 Congratulations!

You've mastered a HARD problem with:
- ✅ Natural progression from simple to optimal
- ✅ INTUITIVE bitmask explanation (from scratch!)
- ✅ Complete optimizations with reasoning
- ✅ Both top-down and bottom-up DP

Ready for the next challenge! 🚀✨