235. Partition to K Equal Sum Subsets
🔗 LeetCode Problem: 698. Partition to K Equal Sum Subsets
📊 Difficulty: Hard
🏷️ Topics: Dynamic Programming, Backtracking, Bit Manipulation, Memoization, Bitmask DP
Problem Statement
Given an integer array nums and an integer k, return true if it is possible to divide this array into k non-empty subsets whose sums are all equal.
Example 1:
Input: nums = [4,3,2,3,5,2,1], k = 4
Output: true
Explanation: It is possible to divide it into 4 subsets (5), (1,4), (2,3), (2,3) with equal sums.
Example 2:
Input: nums = [1,2,3,4], k = 3
Output: false
Explanation: It is impossible to divide it into 3 equal sum subsets.
Constraints:
- 1 <= k <= nums.length <= 16
- 1 <= nums[i] <= 10^4
- The frequency of each element is in the range [1, 4]
🔍 Let's Discover the Pattern Together
Start with Small Examples
Let's not jump to the solution. Let's understand what's really happening by trying small examples:
Example 1: nums = [4, 3, 2, 3, 5, 2, 1], k = 4
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Total sum = 4+3+2+3+5+2+1 = 20
Want: 4 equal subsets
If 4 subsets with equal sums:
Each subset sum = 20/4 = 5
Now the question becomes:
"Can we form 4 subsets, each with sum 5?"
Let's try:
Subset 1: {5} → sum = 5 ✓
Subset 2: {4, 1} → sum = 5 ✓
Subset 3: {3, 2} → sum = 5 ✓
Subset 4: {3, 2} → sum = 5 ✓
Yes! It's possible! ✓
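A proposed answer like the one above can be checked mechanically. Below is a minimal sketch; the class `PartitionChecker` and its method are hypothetical helpers, not part of the solution. It checks two things: every number is used exactly once (compared as sorted multisets), and all groups are non-empty with the same sum.

```java
import java.util.Arrays;

// Hypothetical helper (not part of the solution): checks that a proposed grouping
// uses every element of nums exactly once and that all groups share one sum.
class PartitionChecker {
    static boolean isValidPartition(int[] nums, int[][] groups) {
        if (groups.length == 0) return false;
        // Flatten the groups so they can be compared with nums as a multiset.
        int[] placed = Arrays.stream(groups).flatMapToInt(Arrays::stream).toArray();
        int[] expected = nums.clone();
        Arrays.sort(expected);
        Arrays.sort(placed);
        if (!Arrays.equals(expected, placed)) return false;
        // Every group must be non-empty and have the same sum as the first group.
        int sum = Arrays.stream(groups[0]).sum();
        for (int[] g : groups) {
            if (g.length == 0 || Arrays.stream(g).sum() != sum) return false;
        }
        return true;
    }

    public static void main(String[] args) {
        int[] nums = {4, 3, 2, 3, 5, 2, 1};
        System.out.println(isValidPartition(nums, new int[][] {{5}, {4, 1}, {3, 2}, {3, 2}})); // true
        System.out.println(isValidPartition(nums, new int[][] {{5, 4}, {1}, {3, 2}, {3, 2}})); // false
    }
}
```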
Example 2: nums = [1, 2, 3, 4], k = 3
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Total sum = 1+2+3+4 = 10
Want: 3 equal subsets
Each subset sum = 10/3 = 3.33...
STOP! 🚫
Sum is not divisible by k!
Can't have equal integer sums!
Answer: false ✗
Example 3: nums = [2, 2, 2, 2, 2, 2], k = 3
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Total sum = 12
Each subset sum = 12/3 = 4
Can we form 3 subsets with sum 4?
Subset 1: {2, 2} → sum = 4 ✓
Subset 2: {2, 2} → sum = 4 ✓
Subset 3: {2, 2} → sum = 4 ✓
Yes! ✓
Notice the Pattern
Key observations:
1. First check: Is total divisible by k?
If NO → impossible!
2. Target sum = total / k
Each subset must sum to this target
3. Problem transforms to:
"Can we group all numbers into k groups,
where each group sums to target?"
4. This is like filling k buckets!
Each bucket needs to reach exactly 'target'
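Observations 1 and 2 translate directly into a precheck that runs before any search. Here is a small sketch; `PartitionPrecheck` and `targetOrMinusOne` are hypothetical names. It returns the per-bucket target, or -1 when the divisibility test already rules the partition out.

```java
// Hypothetical precheck implementing observations 1 and 2 above.
class PartitionPrecheck {
    // Returns total / k when a partition is still possible, or -1 when
    // the total is not divisible by k.
    static int targetOrMinusOne(int[] nums, int k) {
        int total = 0;
        for (int x : nums) total += x;
        if (total % k != 0) return -1; // observation 1: total must be divisible by k
        return total / k;              // observation 2: every subset sums to this
    }

    public static void main(String[] args) {
        System.out.println(targetOrMinusOne(new int[] {4, 3, 2, 3, 5, 2, 1}, 4)); // 5
        System.out.println(targetOrMinusOne(new int[] {1, 2, 3, 4}, 3));          // -1
        System.out.println(targetOrMinusOne(new int[] {2, 2, 2, 2, 2, 2}, 3));    // 4
    }
}
```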
The Bucket-Filling Intuition
Think of it as k buckets:
Buckets: [ ] [ ] [ ] [ ] (k = 4)
Target: 5 5 5 5
Numbers to distribute: 4, 3, 2, 3, 5, 2, 1
For each number, pick a bucket!
Number 4 → Bucket 0: [4, _, _]
Number 3 → Bucket 1: [3, _, _]
Number 2 → Bucket 0: [4, 2, _] (now 6 > 5 ✗)
→ Bucket 2: [2, _, _] (try this!)
...
Try all possibilities!
This is BACKTRACKING! 🔑
💡 The AHA Moment - It's a Backtracking Problem!
Connection to Previous Problems
Problem 228: Partition Equal Subset Sum (k=2)
→ Can we make 2 equal groups?
→ Used DP (subset sum)
Problem 233: Last Stone Weight II (k=2)
→ Minimize difference between 2 groups
→ Used DP (subset sum)
THIS Problem: k can be > 2!
→ Can't use simple DP
→ Need BACKTRACKING (try all assignments)
The jump from k=2 to k>2 is SIGNIFICANT! 🔥
Why Not DP?
For k=2:
State: "Can we make sum X?"
DP: dp[sum] = true/false
Works because only 2 groups!
For k>2:
Need to track MULTIPLE group sums simultaneously
dp[sum1][sum2][sum3]...? Too many dimensions!
Better approach: Try all assignments!
→ BACKTRACKING
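To see why k = 2 fits in one dimension, here is a sketch of the subset-sum table described above. `TwoWayPartition` is a hypothetical name, and this is the standard 0/1 subset-sum DP, not code from the earlier problems. If one group reaches total / 2, the other group automatically does too. That is exactly the shortcut that disappears when k > 2.

```java
// Hypothetical sketch of the k = 2 case: dp[s] = "some subset reaches sum s".
class TwoWayPartition {
    static boolean canSplitInTwo(int[] nums) {
        int total = 0;
        for (int x : nums) total += x;
        if (total % 2 != 0) return false;
        int half = total / 2;
        boolean[] dp = new boolean[half + 1];
        dp[0] = true; // the empty subset reaches sum 0
        for (int x : nums) {
            // Iterate downward so each number is used at most once.
            for (int s = half; s >= x; s--) {
                dp[s] = dp[s] || dp[s - x];
            }
        }
        return dp[half];
    }

    public static void main(String[] args) {
        System.out.println(canSplitInTwo(new int[] {4, 3, 2, 1})); // true (4+1 | 3+2)
        System.out.println(canSplitInTwo(new int[] {1, 2, 5}));    // false
    }
}
```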
🔴 Approach 1: Simple Backtracking (What We Discover First)
📐 Function Definition
Function Signature:
boolean recursive(int[] nums, int k, int index, int[] buckets)
What does this function compute?
- Input index: Current number we're trying to place (nums[index])
- Input buckets: Array of k buckets, buckets[i] = current sum in bucket i
- Input nums: The array of numbers
- Input k: Number of buckets
- Output: true if we can distribute remaining numbers such that all buckets become equal
What does the recursive call represent?
- For current number nums[index], try placing it in EACH of the k buckets
- For each bucket i:
- Place nums[index] in bucket i
- Recursively try to place remaining numbers (index+1)
- If successful, return true
- Otherwise, backtrack (remove from bucket i and try next bucket)
- If no bucket works, return false
The Recursive Structure:
Base case:
If index == n (all numbers placed):
Check if all buckets have the SAME sum
return true if equal, false otherwise
For current number nums[index]:
Try each bucket i from 0 to k-1:
buckets[i] += nums[index] // Place in bucket i
if recursive(index+1, buckets): // Try remaining numbers
return true // Found solution!
buckets[i] -= nums[index] // Backtrack
return false // No bucket worked
Why this works:
We systematically try all possible assignments:
- Each number can go into any of k buckets
- We explore all k^n combinations
- When we've placed all numbers, check if buckets are equal
- Backtrack when necessary
This is exhaustive search!
💡 Intuition
Think of placing items one by one:
nums = [4, 3, 2, 1], k = 2
Start with empty buckets: [0, 0]
Place 4:
Try bucket 0: [4, 0]
Place 3:
Try bucket 0: [7, 0]
Place 2:
Try bucket 0: [9, 0]
Place 1:
Try bucket 0: [10, 0]
All placed! Check: 10 ≠ 0 ✗
Try bucket 1: [9, 1]
All placed! Check: 9 ≠ 1 ✗
Backtrack!
Try bucket 1: [7, 2]
... (continue)
Try bucket 1: [4, 3]
... (continue)
Tree of possibilities!
Explore until we find a solution or exhaust all options.
📝 Implementation (Simple Version - What We Discover First)
class Solution {
public boolean canPartitionKSubsets(int[] nums, int k) {
int total = 0;
for (int num : nums) {
total += num;
}
// Early check: divisible?
if (total % k != 0) return false;
int[] buckets = new int[k];
return recursive(nums, k, 0, buckets);
}
private boolean recursive(int[] nums, int k, int index, int[] buckets) {
// Base case: all numbers placed
if (index == nums.length) {
// Check if all buckets have same sum
for (int i = 1; i < buckets.length; i++) {
if (buckets[i] != buckets[i - 1]) {
return false;
}
}
return true;
}
// Try placing nums[index] in each bucket
for (int i = 0; i < k; i++) {
buckets[i] += nums[index]; // Place
if (recursive(nums, k, index + 1, buckets)) {
return true; // Found solution!
}
buckets[i] -= nums[index]; // Backtrack
}
return false; // No bucket worked
}
}
🔍 Detailed Dry Run: nums = [4, 3, 2, 1], k = 2
total = 10, target = 5
Initial buckets: [0, 0]
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
BACKTRACKING EXPLORATION (Showing key paths)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
recursive(index=0, buckets=[0,0], num=4):
Try bucket 0:
buckets = [4, 0]
recursive(index=1, buckets=[4,0], num=3):
Try bucket 0:
buckets = [7, 0]
recursive(index=2, buckets=[7,0], num=2):
Try bucket 0:
buckets = [9, 0]
recursive(index=3, buckets=[9,0], num=1):
Try bucket 0:
buckets = [10, 0]
recursive(index=4, buckets=[10,0]):
Base case! Check buckets:
10 ≠ 0 → return false ✗
Backtrack: buckets = [9, 0]
Try bucket 1:
buckets = [9, 1]
recursive(index=4, buckets=[9,1]):
Base case! Check buckets:
9 ≠ 1 → return false ✗
Backtrack: buckets = [9, 0]
return false ✗
Backtrack: buckets = [7, 0]
Try bucket 1:
buckets = [7, 2]
(continue exploration...)
This path also fails ✗
return false ✗
Backtrack: buckets = [4, 0]
Try bucket 1:
buckets = [4, 3]
recursive(index=2, buckets=[4,3], num=2):
Try bucket 0:
buckets = [6, 3]
(6 already exceeds target 5, so every completion fails)
return false ✗
Try bucket 1:
buckets = [4, 5]
recursive(index=3, buckets=[4,5], num=1):
Try bucket 0:
buckets = [5, 5]
recursive(index=4, buckets=[5,5]):
Base case! Check buckets:
5 == 5 → return true ✓
index=3 returns true ✓
index=2 returns true ✓
index=1 returns true ✓
index=0 returns true ✓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
RESULT: true ✓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Final partition:
Bucket 0: [4, 1] = 5
Bucket 1: [3, 2] = 5
Both equal! Success! ✓
Why This Is Slow
Time Complexity: O(k^n)
- For each number: k choices
- n numbers total
- Total: k × k × k × ... (n times) = k^n
For nums.length = 16, k = 4:
4^16 = 4,294,967,296 (over 4 billion!)
WAY TOO SLOW! ❌
We need optimizations! →
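The k^n count can be observed directly. Below is a minimal instrumented copy of the Approach 1 recursion; the class `LeafCounter` and its static counter are hypothetical. It counts how many complete assignments reach the base case. On inputs with no solution, nothing returns early, so every one of the k^n leaves is visited.

```java
// Hypothetical instrumented copy of Approach 1 that counts base-case visits.
class LeafCounter {
    static long leaves;

    static boolean recursive(int[] nums, int k, int index, int[] buckets) {
        if (index == nums.length) {
            leaves++; // one complete assignment of all numbers
            for (int i = 1; i < buckets.length; i++) {
                if (buckets[i] != buckets[i - 1]) return false;
            }
            return true;
        }
        for (int i = 0; i < k; i++) {
            buckets[i] += nums[index];
            if (recursive(nums, k, index + 1, buckets)) return true;
            buckets[i] -= nums[index];
        }
        return false;
    }

    static long countLeaves(int[] nums, int k) {
        leaves = 0;
        recursive(nums, k, 0, new int[k]);
        return leaves;
    }

    public static void main(String[] args) {
        // Total 5 is odd, so no split works: all 2^4 leaves are checked.
        System.out.println(countLeaves(new int[] {1, 1, 1, 2}, 2)); // 16
        // Total 10 is not divisible by 3: all 3^4 leaves are checked.
        System.out.println(countLeaves(new int[] {1, 2, 3, 4}, 3)); // 81
    }
}
```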
📊 Complexity Analysis
Time: O(k^n) - Exponential!
Space: O(n) - Recursion stack
🟡 Approach 2: Optimized Backtracking (Add Optimizations One by One)
🎯 Optimization 1: Check Target Instead of Final Comparison
PROBLEM with Approach 1:
We check if buckets are equal ONLY at the end
But we can detect violations EARLIER!
OPTIMIZATION:
Instead of checking at the end,
ensure buckets NEVER exceed target!
target = total / k
Before placing:
if buckets[i] + nums[index] > target:
skip this bucket! // Will exceed target
This prunes MANY invalid paths early! ✓
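To measure the effect of the target check, here is a hypothetical instrumented search (`PruningCounter` is not part of the solution). It counts recursive calls with and without Optimization 1 on nums = [3, 3, 3, 3, 4, 4], k = 4, target = 5. That input has no valid partition, so the whole tree is explored in both modes. The base case compares each bucket to target, which is equivalent to "all buckets equal" here because total = k × target.

```java
// Hypothetical instrumented search: counts calls with and without Optimization 1.
class PruningCounter {
    static long calls;

    static boolean search(int[] nums, int index, int[] buckets, int target, boolean prune) {
        calls++;
        if (index == nums.length) {
            for (int b : buckets) {
                if (b != target) return false;
            }
            return true;
        }
        for (int i = 0; i < buckets.length; i++) {
            // Optimization 1: skip a bucket that would overflow the target.
            if (prune && buckets[i] + nums[index] > target) continue;
            buckets[i] += nums[index];
            if (search(nums, index + 1, buckets, target, prune)) return true;
            buckets[i] -= nums[index];
        }
        return false;
    }

    static long countCalls(int[] nums, int k, boolean prune) {
        int total = 0;
        for (int x : nums) total += x;
        calls = 0;
        search(nums, 0, new int[k], total / k, prune);
        return calls;
    }

    public static void main(String[] args) {
        int[] nums = {3, 3, 3, 3, 4, 4};
        System.out.println(countCalls(nums, 4, false)); // 5461 = 1 + 4 + 16 + ... + 4^6
        System.out.println(countCalls(nums, 4, true));  // 65
    }
}
```

With pruning, each 3 must go into its own bucket (3 + 3 > 5). Once all four 3s are placed, the first 4 fits nowhere, so the search stops at depth 4 instead of 6.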
🎯 Optimization 2: Sort in Descending Order
WHY sort descending?
Place LARGE numbers first!
Large numbers are more CONSTRAINED
→ Fewer valid placements
→ Fails faster if impossible
→ Prunes search space earlier
Example:
nums = [1, 1, 1, 1, 5], k = 3, target = 3
If we place 1's first:
Many ways to place them
Explore lots of paths
Eventually find 5 doesn't fit
If we place 5 first:
Only a few ways to place it
Find failure immediately!
Sort descending = explore hard choices first! ✓
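The implementations in this article sort ascending and then call their own `reverse` helper. As an alternative, here is a sketch that uses only standard library calls. `DescendingSort` is a hypothetical name. It boxes the values, sorts them with a reversed comparator, and unboxes them again. This allocates `Integer` objects, which is harmless for n ≤ 16, and it returns a new array instead of sorting in place.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical alternative to Arrays.sort + reverse(): returns a new descending array.
class DescendingSort {
    static int[] sortedDescending(int[] nums) {
        return Arrays.stream(nums)
                .boxed()                            // int -> Integer so a Comparator can be used
                .sorted(Comparator.reverseOrder())  // largest first
                .mapToInt(Integer::intValue)        // back to primitives
                .toArray();
    }

    public static void main(String[] args) {
        int[] nums = {4, 3, 2, 3, 5, 2, 1};
        System.out.println(Arrays.toString(sortedDescending(nums))); // [5, 4, 3, 3, 2, 2, 1]
    }
}
```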
🎯 Optimization 3: Skip Equivalent Empty Buckets
INSIGHT:
If current number doesn't work in an EMPTY bucket,
it won't work in ANY other empty bucket!
Why? All empty buckets are EQUIVALENT!
Example:
buckets = [3, 5, 0, 0]
num = 4
Try bucket 2 (empty):
Fails! (suppose it leads to dead end)
Should we try bucket 3 (also empty)?
NO! It's the same situation!
Skip all remaining empty buckets!
if (buckets[i] == 0) break;
This dramatically reduces branches! ✓
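The effect of the `break` can be measured the same way. Below is a hypothetical instrumented search (`EmptyBucketCounter` is not part of the solution). It always applies the target check and toggles only Optimization 3. The input is again nums = [3, 3, 3, 3, 4, 4], k = 4, target = 5, which has no solution. The base case returns true directly, which assumes the caller has already checked that total % k == 0.

```java
// Hypothetical instrumented search: counts calls with and without Optimization 3.
class EmptyBucketCounter {
    static long calls;

    static boolean search(int[] nums, int index, int[] buckets, int target, boolean skipEmpty) {
        calls++;
        if (index == nums.length) return true; // no bucket ever exceeded target
        for (int i = 0; i < buckets.length; i++) {
            if (buckets[i] + nums[index] > target) continue; // Optimization 1
            buckets[i] += nums[index];
            if (search(nums, index + 1, buckets, target, skipEmpty)) return true;
            buckets[i] -= nums[index];
            // Optimization 3: failing in an empty bucket means failing in every empty bucket.
            if (skipEmpty && buckets[i] == 0) break;
        }
        return false;
    }

    static long countCalls(int[] nums, int k, boolean skipEmpty) {
        int total = 0;
        for (int x : nums) total += x;
        calls = 0;
        search(nums, 0, new int[k], total / k, skipEmpty);
        return calls;
    }

    public static void main(String[] args) {
        int[] nums = {3, 3, 3, 3, 4, 4};
        System.out.println(countCalls(nums, 4, false)); // 65
        System.out.println(countCalls(nums, 4, true));  // 5
    }
}
```

Without the `break`, the four 3s are spread over the four symmetric buckets in all 4! orders. With it, only one representative order is explored.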
📝 Implementation (With All Optimizations)
import java.util.Arrays;
class Solution {
public boolean canPartitionKSubsets(int[] nums, int k) {
int total = 0;
for (int num : nums) {
total += num;
}
if (total % k != 0) return false;
int target = total / k;
// Optimization 2: Sort descending
Arrays.sort(nums);
reverse(nums);
// Early check: if largest > target, impossible
if (nums[0] > target) return false;
int[] buckets = new int[k];
return backtrack(0, buckets, nums, target);
}
private boolean backtrack(int index, int[] buckets, int[] nums, int target) {
// Base case: all numbers placed
if (index == nums.length) {
// Optimization 1: No need to check!
// No bucket was ever allowed to exceed target, and together the
// buckets hold total = k * target, so every bucket equals target exactly.
return true;
}
int num = nums[index];
for (int i = 0; i < buckets.length; i++) {
// Optimization 1: Check target constraint
if (buckets[i] + num > target) {
continue; // Skip this bucket
}
buckets[i] += num; // Place
if (backtrack(index + 1, buckets, nums, target)) {
return true;
}
buckets[i] -= num; // Backtrack
// Optimization 3: Skip equivalent empty buckets
if (buckets[i] == 0) {
break; // Don't try other empty buckets
}
}
return false;
}
private void reverse(int[] arr) {
int i = 0, j = arr.length - 1;
while (i < j) {
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
i++;
j--;
}
}
}
Understanding the Optimizations in Action
nums = [5, 4, 3, 1, 1], k = 2, target = 7
WITHOUT optimizations:
Try all 2^5 = 32 combinations
Check at end if equal
Many wasted paths
WITH Optimization 1 (target check):
Prune paths where bucket > 7
Fewer paths to explore
WITH Optimization 2 (sort descending):
Place 5 first (most constrained)
[5,_] [_,_] or [_,_] [5,_]
Quickly see 5+4 = 9 > 7
Prune early!
WITH Optimization 3 (skip empty):
If 5 fails in first empty bucket,
don't try second empty bucket!
Save half the work!
Combined: MUCH faster! ✓
Still Exponential, But Much Better
Time Complexity: Still O(k^n) worst case
But pruning reduces practical runtime significantly!
Space Complexity: O(n) recursion stack
For most test cases: passes!
But for some hard cases: still times out
Can we do better? → MEMOIZATION!
🟢 Approach 3: Memoization with Bitmask - Building Intuition from Scratch
🤔 The Memoization Challenge
To memoize, we need to cache:
memo[state] = result
But what is our STATE?
(index, buckets[])
PROBLEM:
buckets[] is an ARRAY!
buckets = [3, 5, 2, 0]
buckets = [5, 3, 0, 2]
These are DIFFERENT arrays
But represent SAME state! (just reordered)
Can't use array as memo key! ✗
We need a BETTER state representation! 🔑
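In Java there is a second, more mechanical reason why `buckets[]` is an awkward key. Here is a small sketch; `ArrayKeyDemo` is a hypothetical class. It shows that `int[]` inherits identity-based `equals`/`hashCode` from `Object`, so a hash set does not recognize a second array with the same contents. It also shows that even a content-based comparison (`Arrays.equals`) still treats reordered buckets as different.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical demo: why a raw buckets[] array is a poor memo key.
class ArrayKeyDemo {
    // Arrays use identity equality, so an equal-looking array is not found.
    static boolean sameContentsFound() {
        Set<int[]> seen = new HashSet<>();
        seen.add(new int[] {3, 5, 2, 0});
        return seen.contains(new int[] {3, 5, 2, 0});
    }

    // Content comparison still distinguishes reordered buckets,
    // even though they describe the same search state.
    static boolean reorderedEqual() {
        return Arrays.equals(new int[] {3, 5, 2, 0}, new int[] {5, 3, 0, 2});
    }

    public static void main(String[] args) {
        System.out.println(sameContentsFound()); // false
        System.out.println(reorderedEqual());    // false
    }
}
```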
💡 The KEY Insight - Change Perspective!
CURRENT APPROACH (Number-by-Number):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
"For each number, pick a bucket"
State: (index, buckets[])
Problem: buckets[] hard to memoize
NEW APPROACH (Bucket-by-Bucket):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
"Fill buckets ONE BY ONE"
Track which numbers are USED (not which bucket they're in)
State: (which numbers used, current bucket's sum)
Use BITMASK for "which numbers used"!
→ Single integer!
→ Easy to memoize! ✓
This is the TRANSFORMATION we need! 🎯
🧒 Understanding Bitmask - Like You're 5 Years Old
What is a Bitmask? (Start from Scratch)
Imagine you have 4 toys: Car, Ball, Doll, Train
Toys: Car Ball Doll Train
Index: 0 1 2 3
Question: "Which toys did you play with today?"
NORMAL way to track (array of booleans):
played = [true, false, true, false]
→ Played with Car and Doll
BITMASK way (single number in binary):
mask = 0101 (in binary)
||||
|||└─ Toy 0 (Car): 1 = YES, played with it
||└── Toy 1 (Ball): 0 = NO, didn't play
|└─── Toy 2 (Doll): 1 = YES, played with it
└──── Toy 3 (Train): 0 = NO, didn't play
Same information, but stored in ONE number! ✨
In decimal: 0101 (binary) = 5 (decimal)
So mask = 5 means "played with Car and Doll"
How Does This Help Us?
For our problem:
nums = [4, 3, 2, 1]
Instead of tracking "which bucket is each number in?"
We track "which numbers have been used?"
mask = 0000 (binary) = 0 → none used yet
mask = 0001 (binary) = 1 → nums[0]=4 used
mask = 0011 (binary) = 3 → nums[0]=4, nums[1]=3 used
mask = 1011 (binary) = 11 → nums[0]=4, nums[1]=3, nums[3]=1 used
mask = 1111 (binary) = 15 → all four used!
Each mask is just ONE NUMBER!
Perfect for memoization! ✓
Building Intuition: How to Read a Bitmask
Let's practice reading bitmasks:
nums = [4, 3, 2, 1]
index: 0 1 2 3
mask = 5 in decimal = 0101 in binary
Read right to left:
Position 0: 1 → nums[0]=4 is USED ✓
Position 1: 0 → nums[1]=3 is NOT used
Position 2: 1 → nums[2]=2 is USED ✓
Position 3: 0 → nums[3]=1 is NOT used
So mask=5 means: Used numbers {4, 2}
Another example:
mask = 12 in decimal = 1100 in binary
Position 0: 0 → nums[0]=4 NOT used
Position 1: 0 → nums[1]=3 NOT used
Position 2: 1 → nums[2]=2 USED ✓
Position 3: 1 → nums[3]=1 USED ✓
So mask=12 means: Used numbers {2, 1}
Simple pattern:
Each bit position corresponds to a number!
1 = used, 0 = not used
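The reading procedure above can be written as a short loop. In this sketch, `MaskReader.usedNumbers` is a hypothetical helper. It walks the bit positions from 0 upward and collects `nums[i]` whenever bit i is set. It uses the bit test that is formalized just below as Operation 1. Note that `Integer.toBinaryString` omits leading zeros, so 5 prints as `101` rather than `0101`.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: list the numbers a mask marks as used.
class MaskReader {
    static List<Integer> usedNumbers(int[] nums, int mask) {
        List<Integer> used = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            if ((mask & (1 << i)) != 0) used.add(nums[i]); // bit i set -> nums[i] used
        }
        return used;
    }

    public static void main(String[] args) {
        int[] nums = {4, 3, 2, 1};
        System.out.println(Integer.toBinaryString(5) + " -> " + usedNumbers(nums, 5));   // 101 -> [4, 2]
        System.out.println(Integer.toBinaryString(12) + " -> " + usedNumbers(nums, 12)); // 1100 -> [2, 1]
    }
}
```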
The TWO Key Bitmask Operations
OPERATION 1: Check if number i is used
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Example: Is nums[1]=3 used in mask=5?
mask = 5 = 0101
Want to check position 1 (second from right)
Step 1: Create a "probe" with only position 1 set
(1 << 1) = 0010 (only position 1 is 1)
Step 2: AND with mask
0101 (mask=5)
& 0010 (1 << 1)
──────
0000 (result = 0)
Result = 0 means position 1 is NOT set
So nums[1] is NOT used! ✓
Check nums[0]:
(1 << 0) = 0001
0101 & 0001 = 0001 (non-zero!)
So nums[0] IS used! ✓
Formula: (mask & (1 << i)) != 0
If NOT zero → number i is USED
If zero → number i is NOT used
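Operation 1 fits in a one-line method. In this sketch, `BitCheck.isUsed` is a hypothetical name. It reproduces the two worked checks above on mask = 5.

```java
// Hypothetical helper implementing Operation 1: (mask & (1 << i)) != 0.
class BitCheck {
    static boolean isUsed(int mask, int i) {
        // The probe (1 << i) has only bit i set; AND keeps that bit of mask and clears the rest.
        return (mask & (1 << i)) != 0;
    }

    public static void main(String[] args) {
        int mask = 5; // 0101: nums[0] and nums[2] are used
        for (int i = 0; i < 4; i++) {
            System.out.println("nums[" + i + "] used? " + isUsed(mask, i));
        }
        // prints true, false, true, false for i = 0..3
    }
}
```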
OPERATION 2: Mark number i as used
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Example: Mark nums[1]=3 as used in mask=5
mask = 5 = 0101
Want to turn ON position 1
Step 1: Create mask with only position 1 set
(1 << 1) = 0010
Step 2: OR with current mask
0101 (mask=5)
| 0010 (1 << 1)
──────
0111 (result = 7)
New mask = 7 = 0111
Now position 1 is ON!
Used numbers: {4, 3, 2}
Formula: mask | (1 << i)
Turns ON bit i (marks as used)
Visual summary:
Before: 0101 (nums 4,2 used)
After: 0111 (nums 4,3,2 used)
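Operation 2 is just as short. In this sketch, `BitMark.markUsed` is a hypothetical name. It repeats the 5 → 7 example above. It also shows that OR-ing in a bit that is already set leaves the mask unchanged.

```java
// Hypothetical helper implementing Operation 2: mask | (1 << i).
class BitMark {
    static int markUsed(int mask, int i) {
        // OR with a probe that has only bit i set turns that bit on and keeps the others.
        return mask | (1 << i);
    }

    public static void main(String[] args) {
        int before = 5; // 0101
        int after = markUsed(before, 1);
        System.out.println(Integer.toBinaryString(before) + " -> " + Integer.toBinaryString(after)); // 101 -> 111
        System.out.println(after);              // 7
        System.out.println(markUsed(after, 1)); // 7 (bit 1 was already set)
    }
}
```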
Putting It Together: Our State
OLD STATE (hard to memoize):
(index, buckets[])
buckets[] = which bucket has what
Example: [3, 5, 2] means bucket0=3, bucket1=5, bucket2=2
NEW STATE (easy to memoize):
(mask, currentSum)
mask = which numbers are used (in completed buckets or in the bucket currently being filled)
currentSum = sum of bucket we're currently filling
Example:
nums = [4, 3, 2, 1], target = 5
State: (mask=0001, currentSum=4)
Meaning:
- nums[0]=4 is used (it sits in the current, unfinished bucket)
- We're filling current bucket
- Current bucket has sum = 4 so far
- Need 1 more to reach target 5
State: (mask=1001, currentSum=0)
Meaning:
- nums[0]=4 and nums[3]=1 are used
- They completed a bucket! [4,1]=5 ✓
- Now currentSum=0 means starting a NEW bucket
- Need to pick from unused: {3, 2}
The mask + currentSum perfectly captures our state! ✓
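In fact, currentSum can be recomputed from the mask alone. This rests on one assumption: every bucket closed so far sums to exactly target. The search guarantees this, because it never lets a bucket exceed target and closes a bucket only when it reaches target. Under that assumption, whatever is left over after the full buckets belongs to the bucket being filled. Here is a sketch with hypothetical helper names:

```java
// Hypothetical helpers: recover the in-progress bucket sum from the mask alone.
class MaskState {
    // Sum of nums[i] over every bit i that is set in mask.
    static int usedSum(int[] nums, int mask) {
        int sum = 0;
        for (int i = 0; i < nums.length; i++) {
            if ((mask & (1 << i)) != 0) sum += nums[i];
        }
        return sum;
    }

    // Completed buckets each hold exactly target; the remainder is the current bucket.
    // A just-completed bucket reads as 0 ("start a new bucket").
    static int currentSum(int[] nums, int mask, int target) {
        return usedSum(nums, mask) % target;
    }

    public static void main(String[] args) {
        int[] nums = {4, 3, 2, 1};
        System.out.println(currentSum(nums, 0b0001, 5)); // 4
        System.out.println(currentSum(nums, 0b1001, 5)); // 0
        System.out.println(currentSum(nums, 0b1011, 5)); // 3
    }
}
```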
📐 Function Definition (Bitmask Approach)
Function Signature:
boolean canPartition(int mask, int currentSum, int target, int[] nums, Boolean[] memo)
What does this function compute?
- Input mask: Bitmask showing which numbers are already used
- Input currentSum: Sum of current bucket being filled
- Input target: Target sum each bucket should reach
- Input nums: The array of numbers
- Input memo: Memoization array indexed by mask
- Output: true if we can partition remaining numbers into equal buckets
The Recursive Structure:
Base case:
if mask == (1 << n) - 1: // All bits set = all numbers used
return true // All buckets complete!
If currentSum == target: // Current bucket full!
Bucket complete! Start new bucket
return canPartition(mask, 0, target, nums, memo)
Check memo:
if memo[mask] != null:
return memo[mask]
Try each unused number:
for i = 0 to n-1:
if (mask & (1 << i)) == 0: // number i NOT used
if currentSum + nums[i] <= target:
newMask = mask | (1 << i) // mark as used
if canPartition(newMask, currentSum + nums[i], ...):
memo[mask] = true
return true
memo[mask] = false
return false
Why this works:
We fill buckets ONE BY ONE:
Bucket 1:
Start with currentSum = 0
Add numbers until currentSum = target
Mark those numbers as used in mask
Bucket 2:
Reset currentSum = 0
Add unused numbers until currentSum = target
Mark those as used too
Continue until all numbers used (mask = 1111...1)
The bitmask tracks every number already placed, whether in a completed bucket or in the one being filled!
We don't care WHICH bucket, just if they're used! ✓
💡 Intuition with Visual Example
Fill buckets sequentially:
target = 5
nums = [4, 3, 2, 1]
index: 0 1 2 3
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
FILLING BUCKET 1
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
State: (mask=0000, currentSum=0)
Used: none, Current bucket: empty
Try nums[0]=4:
State: (mask=0001, currentSum=4)
Used: {4}, Current bucket: [4]
Need 1 more to reach 5
Try nums[3]=1:
State: (mask=1001, currentSum=5)
currentSum == target!
Bucket 1 complete: [4,1] = 5 ✓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
FILLING BUCKET 2
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
State: (mask=1001, currentSum=0)
Used: {4,1} (in bucket 1), Current bucket: empty
Available: {3,2}
Try nums[1]=3:
State: (mask=1011, currentSum=3)
Used: {4,3,1}, Current bucket: [3]
Need 2 more
Try nums[2]=2:
State: (mask=1111, currentSum=5)
currentSum == target!
Bucket 2 complete: [3,2] = 5 ✓
State: (mask=1111, ...)
mask == 1111 (all used!)
Base case: return true! ✓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Final partition:
Bucket 1: [4,1] = 5
Bucket 2: [3,2] = 5
The bitmask elegantly tracked which numbers were used!
📝 Implementation
import java.util.Arrays;
class Solution {
public boolean canPartitionKSubsets(int[] nums, int k) {
int total = 0;
for (int num : nums) {
total += num;
}
if (total % k != 0) return false;
int target = total / k;
int n = nums.length;
for (int num : nums) {
if (num > target) return false;
}
// Optimization: sort descending
Arrays.sort(nums);
reverse(nums);
// memo[mask] = can we partition numbers NOT in mask?
Boolean[] memo = new Boolean[1 << n];
return backtrackBitmask(0, 0, target, nums, memo);
}
private boolean backtrackBitmask(int mask, int currentSum, int target,
int[] nums, Boolean[] memo) {
int n = nums.length;
// Base case: all numbers used
if (mask == (1 << n) - 1) {
return true;
}
// Current bucket complete! Start new bucket
if (currentSum == target) {
// Check memo before recursing
if (memo[mask] != null) {
return memo[mask];
}
boolean result = backtrackBitmask(mask, 0, target, nums, memo);
memo[mask] = result;
return result;
}
// Check memo for current state
if (memo[mask] != null) {
return memo[mask];
}
// Try each unused number
for (int i = 0; i < n; i++) {
// Is number i already used?
if ((mask & (1 << i)) != 0) continue;
// Would this exceed target?
if (currentSum + nums[i] > target) continue;
// Try using this number
int newMask = mask | (1 << i);
if (backtrackBitmask(newMask, currentSum + nums[i], target, nums, memo)) {
memo[mask] = true;
return true;
}
}
memo[mask] = false;
return false;
}
private void reverse(int[] arr) {
int i = 0, j = arr.length - 1;
while (i < j) {
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
i++;
j--;
}
}
}
🔍 Detailed Dry Run: nums = [4,3,2,1], k = 2, target = 5
n = 4
Initial: mask = 0000 (decimal 0), currentSum = 0
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
BITMASK BACKTRACKING
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
backtrackBitmask(mask=0, currentSum=0):
All available: [4,3,2,1]
Try i=0 (num=4):
Check if used: (0 & 1) = 0 ✓ not used
0 + 4 = 4 <= 5 ✓
newMask = 0 | 1 = 1 (binary: 0001)
backtrackBitmask(mask=1, currentSum=4):
Used: {4}, Available: {3,2,1}
Try i=1 (num=3):
Check: (1 & 2) = 0 ✓ not used
4 + 3 = 7 > 5 ✗
Try i=2 (num=2):
Check: (1 & 4) = 0 ✓ not used
4 + 2 = 6 > 5 ✗
Try i=3 (num=1):
Check: (1 & 8) = 0 ✓ not used
4 + 1 = 5 <= 5 ✓
newMask = 1 | 8 = 9 (binary: 1001)
backtrackBitmask(mask=9, currentSum=5):
currentSum == target!
Bucket 1 complete: [4,1] = 5 ✓
Check memo[9]? null
Start new bucket:
backtrackBitmask(mask=9, currentSum=0):
Used: {4,1}, Available: {3,2}
Try i=1 (num=3):
Check: (9 & 2) = 0 ✓ not used
0 + 3 = 3 <= 5 ✓
newMask = 9 | 2 = 11 (binary: 1011)
backtrackBitmask(mask=11, currentSum=3):
Used: {4,3,1}, Available: {2}
Try i=2 (num=2):
Check: (11 & 4) = 0 ✓ not used
3 + 2 = 5 <= 5 ✓
newMask = 11 | 4 = 15 (binary: 1111)
backtrackBitmask(mask=15, currentSum=5):
currentSum == target!
Check memo[15]? null
backtrackBitmask(mask=15, currentSum=0):
mask == 15 == (1<<4)-1
Base case: all used!
return true ✓
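memo[15] = true (stored by the currentSum == target branch)
return true ✓ (mask=15, currentSum=5)
memo[11] = true
return true ✓ (mask=11, currentSum=3)
memo[9] = true
return true ✓ (mask=9, currentSum=0)
memo[9] = true (stored again by the currentSum == target branch)
return true ✓ (mask=9, currentSum=5)
memo[1] = true
return true ✓ (mask=1, currentSum=4)
memo[0] = true
return true ✓ (mask=0, currentSum=0)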
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
RESULT: true ✓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Partition:
Bucket 1: {4,1} = 5 (mask transition: 0→1→9)
Bucket 2: {3,2} = 5 (mask transition: 9→11→15)
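Memo entries saved:
memo[15] = true (all numbers used successfully)
memo[11] = true (after adding 3, can add 2)
memo[9] = true (after bucket 1, can complete the rest)
memo[1] = true (with 4 placed, the rest can be completed)
memo[0] = true (the answer for the whole array)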
Why Memoization Works Here
KEY INSIGHT:
Once we've computed memo[mask],
the answer is ALWAYS the same!
Doesn't matter:
- Which bucket we're filling (1st, 2nd, 3rd, ...)
- What order we used numbers
- How we reached this mask
Only matters:
- Which numbers are used (mask)
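- Where we are in filling the current bucket, which the mask already implies:
every completed bucket holds exactly target, so
currentSum = (sum of used numbers) % target
(a just-completed bucket shows up as 0, i.e. "start a new bucket")
Because currentSum is a function of mask, memo[mask] alone captures the state.
That is why the code can safely store a result for every mask it visits,
not only at bucket boundaries. ✓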
Example:
If memo[9] = true, it means:
"Given that {4,1} are used,
we CAN complete the remaining buckets"
Next time we see mask=9, instant answer! ✓
📊 Complexity Analysis
Time Complexity: O(2^n × n)
States: 2^n possible masks
For each state: try n numbers
Total: O(2^n × n)
With memoization:
Each state computed ONCE
For n=16:
2^16 × 16 = 1,048,576 operations
Feasible! ✓
Space Complexity: O(2^n)
Memo array: 2^n entries
Recursion stack: O(n)
Total: O(2^n)
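To compare the two bounds at n = 16, k = 4, here is a small arithmetic sketch. `BoundComparison` is a hypothetical name. These are operation-count bounds, not measured runtimes.

```java
// Hypothetical arithmetic check of the Approach 1 and Approach 3 bounds.
class BoundComparison {
    // k^n: leaves of the Approach 1 search tree.
    static long backtrackingBound(int n, int k) {
        long result = 1;
        for (int i = 0; i < n; i++) result *= k;
        return result;
    }

    // 2^n masks, each scanning n numbers.
    static long bitmaskBound(int n) {
        return (1L << n) * n;
    }

    public static void main(String[] args) {
        long bt = backtrackingBound(16, 4);
        long bm = bitmaskBound(16);
        System.out.println(bt);      // 4294967296
        System.out.println(bm);      // 1048576
        System.out.println(bt / bm); // 4096
    }
}
```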
🔵 Approach 4: Bottom-Up DP - Deriving from Top-Down
🎯 Can We Do Bottom-Up?
Our top-down uses:
memo[mask] = can we partition numbers NOT in mask?
Can we build this bottom-up?
Challenge:
The recursion fills buckets ONE BY ONE
Order matters: currentSum resets when bucket fills
This makes bottom-up TRICKY!
But we can still do it! Let's derive step by step...
📐 Understanding the State Transition
TOP-DOWN (what we have):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
memo[mask] = can we complete remaining buckets?
starting with currentSum = 0
When currentSum reaches target:
→ Bucket done
→ Reset to 0
→ Continue with same mask
BOTTOM-UP (what we'll build):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
dp[mask] = can we partition numbers in mask into complete buckets
(possibly with incomplete bucket at currentSum)
We also need:
remainder[mask] = currentSum when we have mask
Build from smaller masks to larger masks!
🧒 Deriving Base Cases (ELI5)
BASE CASE in TOP-DOWN:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
if (mask == (1 << n) - 1): // All numbers used
return true
This means: "If all numbers are used, we're done!"
BOTTOM-UP TRANSLATION:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Start from mask = 0 (no numbers used)
Build up to mask = (1 << n) - 1 (all used)
dp[0] = true (no numbers used yet, valid starting point)
remainder[0] = 0 (no current sum)
Goal: compute dp[(1<<n)-1]
If dp[(1<<n)-1] = true → answer is true!
📝 Bottom-Up Implementation
class Solution {
public boolean canPartitionKSubsets(int[] nums, int k) {
int total = 0;
for (int num : nums) {
total += num;
}
if (total % k != 0) return false;
int target = total / k;
int n = nums.length;
for (int num : nums) {
if (num > target) return false;
}
// dp[mask] = can we partition numbers in mask validly?
// remainder[mask] = current bucket's sum for this mask
boolean[] dp = new boolean[1 << n];
int[] remainder = new int[1 << n];
// Base case
dp[0] = true;
remainder[0] = 0;
// Build from smaller masks to larger
for (int mask = 0; mask < (1 << n); mask++) {
if (!dp[mask]) continue; // Skip invalid states
// Try adding each unused number
for (int i = 0; i < n; i++) {
// Is number i already used in mask?
if ((mask & (1 << i)) != 0) continue;
int newMask = mask | (1 << i);
int newRemainder = (remainder[mask] + nums[i]) % target;
// Can we add this number?
if (remainder[mask] + nums[i] <= target) {
dp[newMask] = true;
remainder[newMask] = newRemainder;
}
}
}
return dp[(1 << n) - 1];
}
}
🔍 Detailed Dry Run: nums = [4,2], k = 2, target = 3
n = 2
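target = 3
(Note: the implementation's early num > target check would already return false here, because 4 > 3. The trace below skips that check to show how the table is filled.)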
Initial:
dp[0] = true
remainder[0] = 0
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
BUILD DP TABLE
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Process mask = 0 (binary: 00):
dp[0] = true, remainder[0] = 0
Available: {4, 2}
Try i=0 (num=4):
Not used: ✓
remainder[0] + 4 = 0 + 4 = 4
4 > target (3)? YES ✗
Can't add!
Try i=1 (num=2):
Not used: ✓
remainder[0] + 2 = 0 + 2 = 2
2 <= target? YES ✓
newMask = 0 | 2 = 2 (binary: 10)
newRemainder = (0 + 2) % 3 = 2
dp[2] = true
remainder[2] = 2
Process mask = 1 (binary: 01):
dp[1] = false → skip
Process mask = 2 (binary: 10):
dp[2] = true, remainder[2] = 2
Used: {2}, Available: {4}
Try i=0 (num=4):
Not used: ✓
remainder[2] + 4 = 2 + 4 = 6
6 > target? YES ✗
Can't add!
Try i=1 (num=2):
Already used ✗
Process mask = 3 (binary: 11):
dp[3] = false
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
RESULT: dp[3] = false ✗
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Can't partition [4,2] into 2 groups of sum 3
(4 > 3, impossible!)
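The [4,2] trace never completes a bucket, so remainder never wraps back to 0. The sketch below reruns the same transitions as the Approach 4 loop on nums = [1,2,3], k = 2, target = 3, where it does wrap. The class `BottomUpTrace` is hypothetical, and it marks unreachable masks with -1 so the whole table can be printed. remainder[mask] always equals (sum of nums in mask) % target, so overwriting it from different predecessors is consistent.

```java
import java.util.Arrays;

// Hypothetical trace of the Approach 4 transitions; -1 marks masks with dp[mask] = false.
class BottomUpTrace {
    static int[] buildRemainders(int[] nums, int target) {
        int n = nums.length;
        boolean[] dp = new boolean[1 << n];
        int[] remainder = new int[1 << n];
        dp[0] = true;
        for (int mask = 0; mask < (1 << n); mask++) {
            if (!dp[mask]) continue;
            for (int i = 0; i < n; i++) {
                if ((mask & (1 << i)) != 0) continue;       // nums[i] already used
                if (remainder[mask] + nums[i] <= target) {  // fits in the current bucket
                    int newMask = mask | (1 << i);
                    dp[newMask] = true;
                    remainder[newMask] = (remainder[mask] + nums[i]) % target; // full bucket -> 0
                }
            }
        }
        for (int mask = 0; mask < (1 << n); mask++) {
            if (!dp[mask]) remainder[mask] = -1;
        }
        return remainder;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(buildRemainders(new int[] {1, 2, 3}, 3))); // [0, 1, 2, 0, 0, 1, 2, 0]
        System.out.println(Arrays.toString(buildRemainders(new int[] {4, 2}, 3)));    // [0, -1, 2, -1]
    }
}
```

Masks 3 ({1,2}) and 4 ({3}) both close a bucket and reset to 0. The full mask 7 is reachable, so dp[7] = true.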
Why Bottom-Up is Harder Here
Bottom-up DP works BUT:
- Need to track remainder carefully
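- Order of processing matters: masks go in increasing order, which works because newMask = mask | (1 << i) is always larger than mask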
- Less intuitive than top-down
Top-down is MUCH cleaner for this problem!
- Natural recursion matches bucket-filling
- currentSum tracked naturally
- Memoization straightforward
RECOMMENDATION:
Use top-down (Approach 3) for this problem! ✓
Bottom-up possible but unnecessarily complex
📊 Complexity
Time: O(2^n × n)
Space: O(2^n)
Clarity: Lower than top-down
📊 Complete Approach Comparison
┌────────────────────────────────────────────────────────────────┐
│ PARTITION K EQUAL SUBSETS - APPROACH COMPARISON │
├──────────────────┬──────────┬──────────┬──────────┬────────────┤
│ Approach │ Time │ Space │ Clarity │ Interview │
├──────────────────┼──────────┼──────────┼──────────┼────────────┤
│ Simple Backtrack │ O(k^n) │ O(n) │ High │ Start │
│ Optimized BT │ O(k^n)* │ O(n) │ High │ Good │
│ Bitmask Memo │ O(2^n×n) │ O(2^n) │ Medium │ Best ✓ │
│ Bottom-Up DP │ O(2^n×n) │ O(2^n) │ Low │ Optional │
└──────────────────┴──────────┴──────────┴──────────┴────────────┘
* Much faster with pruning in practice
Constraint n <= 16: Bitmask DP perfect!
💻 Complete Working Code
import java.util.Arrays;
class Solution {
public boolean canPartitionKSubsets(int[] nums, int k) {
return bitmaskDP(nums, k);
}
// Approach 3: Bitmask DP (Top-Down) - O(2^n × n), RECOMMENDED ✓
private boolean bitmaskDP(int[] nums, int k) {
int total = 0;
for (int num : nums) {
total += num;
}
if (total % k != 0)
return false;
int target = total / k;
int n = nums.length;
for (int num : nums) {
if (num > target)
return false;
}
Arrays.sort(nums);
reverse(nums);
Boolean[] memo = new Boolean[1 << n];
return backtrackBitmask(0, 0, target, nums, memo);
}
private boolean backtrackBitmask(int mask, int currentSum, int target, int[] nums, Boolean[] memo) {
int n = nums.length;
if (mask == (1 << n) - 1) {
return true;
}
if (currentSum == target) {
if (memo[mask] != null) {
return memo[mask];
}
boolean result = backtrackBitmask(mask, 0, target, nums, memo);
memo[mask] = result;
return result;
}
if (memo[mask] != null) {
return memo[mask];
}
for (int i = 0; i < n; i++) {
if ((mask & (1 << i)) != 0)
continue;
if (currentSum + nums[i] > target)
continue;
int newMask = mask | (1 << i);
if (backtrackBitmask(newMask, currentSum + nums[i], target, nums, memo)) {
memo[mask] = true;
return true;
}
}
memo[mask] = false;
return false;
}
// Approach 4: Bottom-Up DP - O(2^n × n), less intuitive
private boolean bottomUpDP(int[] nums, int k) {
int total = 0;
for (int num : nums) {
total += num;
}
if (total % k != 0)
return false;
int target = total / k;
int n = nums.length;
for (int num : nums) {
if (num > target)
return false;
}
boolean[] dp = new boolean[1 << n];
int[] remainder = new int[1 << n];
dp[0] = true;
remainder[0] = 0;
for (int mask = 0; mask < (1 << n); mask++) {
if (!dp[mask])
continue;
for (int i = 0; i < n; i++) {
if ((mask & (1 << i)) != 0)
continue;
int newMask = mask | (1 << i);
if (remainder[mask] + nums[i] <= target) {
dp[newMask] = true;
remainder[newMask] = (remainder[mask] + nums[i]) % target;
}
}
}
return dp[(1 << n) - 1];
}
// Approach 2: Optimized Backtracking - O(k^n), practical
private boolean optimizedBacktracking(int[] nums, int k) {
int total = 0;
for (int num : nums) {
total += num;
}
if (total % k != 0)
return false;
int target = total / k;
Arrays.sort(nums);
reverse(nums);
if (nums[0] > target)
return false;
int[] buckets = new int[k];
return backtrackOptimized(0, buckets, nums, target);
}
private boolean backtrackOptimized(int index, int[] buckets, int[] nums, int target) {
if (index == nums.length) {
return true;
}
int num = nums[index];
for (int i = 0; i < buckets.length; i++) {
if (buckets[i] + num > target) {
continue;
}
buckets[i] += num;
if (backtrackOptimized(index + 1, buckets, nums, target)) {
return true;
}
buckets[i] -= num;
if (buckets[i] == 0)
break;
}
return false;
}
// Approach 1: Simple Backtracking - O(k^n), what we discover first
private boolean simpleBacktracking(int[] nums, int k) {
int total = 0;
for (int num : nums) {
total += num;
}
if (total % k != 0)
return false;
int[] buckets = new int[k];
return recursive(nums, k, 0, buckets);
}
private boolean recursive(int[] nums, int k, int index, int[] buckets) {
if (index == nums.length) {
for (int i = 1; i < buckets.length; i++) {
if (buckets[i] != buckets[i - 1]) {
return false;
}
}
return true;
}
for (int i = 0; i < k; i++) {
buckets[i] += nums[index];
if (recursive(nums, k, index + 1, buckets)) {
return true;
}
buckets[i] -= nums[index];
}
return false;
}
private void reverse(int[] arr) {
int i = 0, j = arr.length - 1;
while (i < j) {
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
i++;
j--;
}
}
public static void main(String[] args) {
Solution s = new Solution();
System.out.println(s.canPartitionKSubsets(new int[] { 4, 3, 2, 3, 5, 2, 1 }, 4) == true);
System.out.println(s.canPartitionKSubsets(new int[] { 1, 2, 3, 4 }, 3) == false);
System.out.println(s.canPartitionKSubsets(new int[] { 4, 3, 2, 1 }, 2) == true);
}
}
🔑 Key Insights
The Complete Journey
1. Start Simple:
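Try every bucket for every number → O(k^n)
2. Add Optimizations:
Skip buckets that would exceed target, sort descending, stop after an empty bucket fails
→ Still O(k^n) in the worst case, but usually much faster
3. Hit Wall:
Can't easily memoize the buckets[] array: too many states, and the same buckets in a different order look like a new state!
4. Transform:
Use a bitmask to track used numbers
→ Can memoize! O(2^n × n) time, O(2^n) memo
5. Success:
Top-down bitmask DP is the best approach here!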
Bottom-up possible but complex
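Step 5 mentions a bottom-up version. Below is one compact way it can look, assuming the same bit convention as the top-down code (bit i set means nums[i] is used). `BottomUpPartition` is a hypothetical class name. In this sketch `dp[mask]` stores the sum of the bucket currently being filled, or -1 if no valid filling reaches `mask`. That makes it a different table from the `Boolean[] memo` used top-down.

```java
import java.util.Arrays;

// Hypothetical bottom-up bitmask DP sketch.
// dp[mask] = sum of the bucket being filled after using the numbers in mask,
// or -1 if mask cannot be reached without some bucket exceeding target.
class BottomUpPartition {
    static boolean canPartition(int[] nums, int k) {
        int total = Arrays.stream(nums).sum();
        if (total % k != 0) return false;
        int target = total / k;
        int n = nums.length;
        int[] dp = new int[1 << n];
        Arrays.fill(dp, -1);
        dp[0] = 0; // no numbers used, current bucket empty
        for (int mask = 0; mask < (1 << n); mask++) {
            if (dp[mask] == -1) continue; // unreachable state
            for (int i = 0; i < n; i++) {
                if ((mask & (1 << i)) != 0) continue;       // nums[i] already used
                if (dp[mask] + nums[i] > target) continue;  // would overflow the bucket
                int next = mask | (1 << i);
                // A bucket that reaches target exactly wraps to 0: start the next bucket.
                dp[next] = (dp[mask] + nums[i]) % target;
            }
        }
        return dp[(1 << n) - 1] == 0;
    }
}
```

Visiting masks in increasing order is safe, because setting a bit always yields a larger number. Every state is therefore final before it is extended. Time is O(2^n × n) and space O(2^n), the same bounds as top-down. The difference is that every reachable mask gets expanded instead of stopping at the first success.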
Bitmask Power
Why bitmask is elegant:
✓ Compresses state into single int
✓ Fast bit operations
✓ Perfect for memoization
✓ Works great for n ≤ 16
This technique appears in MANY hard problems!
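Behind "compresses state into single int" there is one detail worth spelling out. The memo is keyed by the mask alone, even though the recursion also carries `currentSum`. That works because, for any mask reached without overflowing a bucket, every completed bucket holds exactly target, so the bucket in progress holds the used total modulo target. The top-down code briefly holds `currentSum == target` before resetting to 0, and that is the same state. The hypothetical helper `MaskState` below just computes this, using nums = [4,3,2,1] with target 5 as an assumed example.

```java
// Hypothetical helper: recovers the bucket state from the mask alone.
// Only meaningful for masks reached without any bucket exceeding target.
class MaskState {
    // Sum of nums[i] over every set bit i (bit 0 = nums[0]).
    static int usedSum(int[] nums, int mask) {
        int sum = 0;
        for (int i = 0; i < nums.length; i++) {
            if ((mask & (1 << i)) != 0) {
                sum += nums[i];
            }
        }
        return sum;
    }

    // Each completed bucket holds exactly target.
    static int completedBuckets(int[] nums, int mask, int target) {
        return usedSum(nums, mask) / target;
    }

    // Whatever is left over sits in the bucket currently being filled.
    static int currentBucketSum(int[] nums, int mask, int target) {
        return usedSum(nums, mask) % target;
    }
}
```

Take mask 1011 (11), which marks {4, 3, 1}. The used sum is 8, so one bucket ({4, 1}) is complete and the open bucket holds 3.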
🎪 Memory Aid
"k=2 → subset-sum DP, k>2 → backtracking or bitmask DP!"
"Bitmask = which numbers used (1=yes, 0=no)!"
"Check bit: (mask & (1<<i)) != 0"
"Set bit: mask | (1<<i)"
"Top-down cleaner than bottom-up!" ✨
📝 Quick Revision Notes
🎯 Core Concept
Problem: Partition into k equal sum subsets
Check: total % k == 0, target = total/k
Simple: Backtracking O(k^n)
Optimal: Bitmask DP O(2^n × n) ✓
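The "Check" line is worth running before any search. The Implementation below also rejects an oversized element with `nums[0] > total / k` after sorting. A single pass can do both checks without sorting. This is a small sketch with the hypothetical helper `PreCheck.targetOrMinusOne`:

```java
// Hypothetical helper: the O(n) checks to run before any search.
class PreCheck {
    // Returns the per-subset target, or -1 when no partition can exist.
    static int targetOrMinusOne(int[] nums, int k) {
        int total = 0;
        int max = 0;
        for (int x : nums) {
            total += x;
            max = Math.max(max, x);
        }
        if (total % k != 0) {
            return -1; // equal integer subset sums are impossible
        }
        int target = total / k;
        if (max > target) {
            return -1; // the largest number fits in no subset
        }
        return target;
    }
}
```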
⚡ Bitmask Quick Reference
Check if i used: (mask & (1 << i)) != 0
Mark i as used: mask | (1 << i)
Example: nums = [4,3,2,1] (bit i ↔ nums[i], bit 0 is the rightmost bit)
mask = 5 = 0101 → {4,2} used
mask = 11 = 1011 → {4,3,1} used
mask = 15 = 1111 → all used
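To double-check the examples above, the hypothetical helper `MaskDecoder` lists the numbers a mask marks as used. It follows the same convention as `(mask & (1 << i)) != 0` in the code: bit i stands for nums[i], and bit 0 is the rightmost bit.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for reading masks; bit i stands for nums[i], bit 0 is the rightmost bit.
class MaskDecoder {
    // Collects nums[i] for every set bit i, in index order.
    static List<Integer> usedNumbers(int[] nums, int mask) {
        List<Integer> used = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            if ((mask & (1 << i)) != 0) { // check bit i
                used.add(nums[i]);
            }
        }
        return used;
    }

    // Marks index i as used (setting a bit that is already set changes nothing).
    static int markUsed(int mask, int i) {
        return mask | (1 << i);
    }
}
```

For nums = [4, 3, 2, 1], `usedNumbers(nums, 5)` gives [4, 2]. `markUsed(5, 1)` gives 7 (0111), which adds the 3.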
Implementation
import java.util.Arrays;
public class Solution {
public boolean canPartitionKSubsets(int[] nums, int k) {
int[] buckets = new int[k];
// return recursive(nums, k, 0, buckets);
// Optimal 1: sort in descending order so the largest numbers are placed first
// and dead ends are found early
Arrays.sort(nums);
reverse(nums);
// Optimal 2: instead of filling every bucket and comparing them only at the end,
// exit early. Each bucket must sum to exactly target = total / k, so once a
// bucket reaches target no further number may be added to it.
int total = Arrays.stream(nums).sum();
if (total % k != 0 || nums[0] > total / k) {
return false;
}
// return recursive_optimal(nums, k, 0, buckets, total / k);
// step 1: consider [4,3,2,1]
// mask tracks which numbers (by index) are already used: bit i is set when
// nums[i] is used, with bit 0 the rightmost bit
// 1111 (15) => all numbers are used
// 0101 (5) => nums[0] = 4 and nums[2] = 2 are used
// 0000 (0) => no number used (the initial call passes mask = 0)
// The memo needs one slot per mask 0000..1111, i.e. 1 << 4 = 10000 (16) slots.
Boolean[] memo = new Boolean[1 << nums.length];
return topDown(nums, total / k, 0, 0, memo);
}
private boolean topDown(int[] nums, int target, int mask, int currentSum, Boolean[] memo) {
// step 7: base case: every number is used
if (mask == (1 << nums.length) - 1) {
// for example, mask = 1111 (15, all used) for [4,3,2,1]
// 1<<4 => 10000 (16)
// (1<<4)-1 => 16-1 = 15 = 1111
return true;
}
// step 8: the current bucket is full. Start a new, empty bucket with the
// same set of used numbers, and cache the result for this mask
if (currentSum == target) {
if (memo[mask] != null) {
return memo[mask];
}
memo[mask] = topDown(nums, target, mask, 0, memo);
}
// step 9: reuse a result already computed for this mask. Keying by mask alone
// is safe because currentSum equals the used sum modulo target
if (memo[mask] != null) {
return memo[mask];
}
// step 2: try each unused number in the current bucket
for (int i = 0; i < nums.length; i++) {
// step 3: check if the number at index i is already used
// e.g. mask = 0011, i = 0: 0011 & 0001 = 0001 != 0, so nums[0] is used
if ((mask & (1 << i)) != 0) {
continue; // used. Proceed to the next number
}
// step 4: skip the number if it would push the bucket above target
if (nums[i] + currentSum > target) {
continue;
}
// step 5: mark index i as used
int updatedMask = mask | (1 << i);
// step 6: recurse with the larger bucket sum
if (topDown(nums, target, updatedMask, currentSum + nums[i], memo)) {
return memo[mask] = true;
}
}
return memo[mask] = false;
}
private void reverse(int[] nums) {
int start = 0;
int end = nums.length - 1;
while (start <= end) {
int temp = nums[start];
nums[start] = nums[end];
nums[end] = temp;
start++;
end--;
}
}
private boolean recursive_optimal(int[] nums, int k, int index, int[] buckets, int target) {
// step 3: base case
// optimal 5: every number is placed and no bucket exceeds target. Since
// total == k * target, every bucket must equal target, so no check is needed
if (index == nums.length) {
return true;
}
// step 1: try placing the number at index in each bucket, then move on to the next index
boolean res = false;
for (int i = 0; i < k; i++) {
// step 2a: try the number at index in bucket i
// Optimal 3: skip bucket i if this number would push it above target
if (buckets[i] + nums[index] > target) {
continue;
}
buckets[i] = buckets[i] + nums[index];
res = recursive_optimal(nums, k, index + 1, buckets, target);
if (res) {
return true;
}
// step 2b: take the number back out of bucket i (backtrack) and try the next bucket
buckets[i] = buckets[i] - nums[index];
// optimal 4: if bucket i is empty again, this number just failed in an empty
// bucket. Every later bucket is empty too, so trying them would repeat the
// same failed search. Skip them all.
if (buckets[i] == 0) {
break;
}
}
return false;
}
private boolean recursive(int[] nums, int k, int index, int[] buckets) {
// step 3: base case
// check if all buckets have same value, then true.
if (index == nums.length) {
for (int i = 1; i < buckets.length; i++) {
if (buckets[i] != buckets[i - 1]) {
return false;
}
}
return true;
}
// step 1: try placing the number at index in each bucket, then move on to the next index
for (int i = 0; i < k; i++) {
// step 2a: try the number at index in bucket i
buckets[i] = buckets[i] + nums[index];
if (recursive(nums, k, index + 1, buckets)) {
return true;
}
// step 2b: take the number back out of bucket i (backtrack) and try the next bucket
buckets[i] = buckets[i] - nums[index];
}
return false;
}
public static void main(String[] args) {
Solution s = new Solution();
System.out.println(s.canPartitionKSubsets(new int[] { 4, 3, 2, 3, 5, 2, 1 }, 4) == true);
System.out.println(s.canPartitionKSubsets(new int[] { 1, 2, 3, 4 }, 3) == false);
}
}
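Here is one possible simplification of `topDown`, sketched as an alternative rather than taken from the original. The step 8 reset can be folded into the recursive call by passing `(currentSum + nums[i]) % target`, so a bucket that hits target exactly continues at 0. `currentSum` then always stays below target, and the memo lookup happens in one place. `ModuloTopDown` is a hypothetical class name. The mask convention and the O(2^n × n) bound are the same as above.

```java
import java.util.Arrays;

// Hypothetical variant of topDown: the "bucket full" reset is folded
// into the recursion by taking the new bucket sum modulo target.
class ModuloTopDown {
    static boolean canPartitionKSubsets(int[] nums, int k) {
        int total = Arrays.stream(nums).sum();
        if (total % k != 0) return false;
        int target = total / k;
        for (int x : nums) {
            if (x > target) return false; // this number fits in no subset
        }
        // memo[mask]: null = not computed yet, TRUE/FALSE = cached answer
        Boolean[] memo = new Boolean[1 << nums.length];
        return solve(nums, target, 0, 0, memo);
    }

    // currentSum is the sum of the bucket being filled, always in [0, target).
    private static boolean solve(int[] nums, int target, int mask, int currentSum, Boolean[] memo) {
        if (mask == (1 << nums.length) - 1) return true; // all numbers placed
        if (memo[mask] != null) return memo[mask];
        for (int i = 0; i < nums.length; i++) {
            if ((mask & (1 << i)) != 0) continue;        // nums[i] already used
            if (currentSum + nums[i] > target) continue; // would overflow the bucket
            // A sum of exactly target wraps to 0, which starts the next bucket.
            int nextSum = (currentSum + nums[i]) % target;
            if (solve(nums, target, mask | (1 << i), nextSum, memo)) {
                return memo[mask] = true;
            }
        }
        return memo[mask] = false;
    }
}
```

This sketch finds the largest element with a single loop instead of sorting. Sorting in descending order, as the Implementation does, can still help by finding dead ends sooner.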
🎉 Congratulations!
You've mastered a HARD problem with:
- ✅ Natural progression from simple to optimal
- ✅ INTUITIVE bitmask explanation (from scratch!)
- ✅ Complete optimizations with reasoning
- ✅ Both top-down and bottom-up DP
Ready for the next challenge! 🚀✨