从Target Sum说起 - DP

从Target Sum说起 - DP

题目:You are given a list of non-negative integers, a1, a2, …, an, and a target, S. Now you have 2 symbols + and -. For each integer, you should choose one from + and - as its new symbol.

Find out how many ways to assign symbols to make sum of integers equal to target S.

Example 1:
Input: nums is [1, 1, 1, 1, 1], S is 3.
Output: 5
Explanation:

-1+1+1+1+1 = 3
+1-1+1+1+1 = 3
+1+1-1+1+1 = 3
+1+1+1-1+1 = 3
+1+1+1+1-1 = 3

There are 5 ways to assign symbols to make the sum of nums be target 3.
Note:
The length of the given array is positive and will not exceed 20.
The sum of elements in the given array will not exceed 1000.
Your output answer is guaranteed to be fitted in a 32-bit integer.

  1. DFS, 暴力搜索。这里因为限制了数目不超过20个数。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
public:
int findTargetSumWays(vector<int>& nums, int S) {
// vector<int> sign(nums.size(), 0);
int count = 0;
// assign(count, sign, nums, 0, S, 0);
dfs(count, nums, 0, S, 0);
return count;
}
void dfs(int& count, vector<int>& nums, int depth, int S, int cur_sum){
if(depth == nums.size()){
if(cur_sum == S){
count++;
}
return;
}
dfs(count, nums, depth+1, S, cur_sum + nums[depth]);
dfs(count, nums, depth+1, S, cur_sum - nums[depth]);
}
};
  1. 对暴力DP的优化, 48%

    或者也可以算是Memorization的DP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
public:
int findTargetSumWays(vector<int>& nums, int S) {
vector<unordered_map<int,int> > cache(nums.size());
return dfs(nums, 0, S, 0, cache);
}
int dfs(vector<int>& nums, int depth, int S, int cur_sum, vector<unordered_map<int,int> >& cache){
if(depth == nums.size()){
return cur_sum == S;
}
if(cache[depth].count(cur_sum)) return cache[depth][cur_sum];
int left = dfs(nums, depth+1, S, cur_sum + nums[depth], cache);
int right = dfs(nums, depth+1, S, cur_sum - nums[depth], cache);
return cache[depth][cur_sum] = left + right;
}
};
  1. Bottom-Up DP

    二维DP

    Define:

    S = sum($a_i$)

    当所有值取正,结果就是S, 相反所有值取负,结果是-S. 加上0,所有的取值空间是2*S + 1.

    这也是为什么DP会比DFS快的原因。DFS的搜索空间是2^n.

    当然因为数组是0 index的,所以在做的时候需要设置offset, 这里这是为S,相当于整体向正方向偏移offset长度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// DP Push, 59.61%
class Solution {
public:
int findTargetSumWays(vector<int>& nums, int S) {
int n = nums.size();
int sum = accumulate(nums.begin(), nums.end(), 0);
if(sum < S) return 0;
int offset = sum;
// first n items to achieve the sum
vector<vector<int> > dp(n + 1, vector<int>(2 * sum + 1, 0));
dp[0][offset] = 1;
//for(int i=0; i<n; i++){
// for(int j=nums[i]; j < 2 * sum + 1 - nums[i]; j++){
// if(dp[i][j]){
// // push
// dp[i+1][j-nums[i]] += dp[i][j];
// dp[i+1][j+nums[i]] += dp[i][j];
// }
// }
//}
for(int i=0; i<n; i++){
for(int j=0; j<2*sum + 1; j++){
// push
if(j-nums[i] >= 0)
dp[i+1][j-nums[i]] += dp[i][j];
if(j+nums[i] < 2*sum + 1)
dp[i+1][j+nums[i]] += dp[i][j];
}
}
return dp.back()[offset + S];
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// DP, pull, 59.61%
class Solution {
public:
int findTargetSumWays(vector<int>& nums, int S) {
int n = nums.size();
int sum = accumulate(nums.begin(), nums.end(), 0);
if(sum < S) return 0;
int offset = sum;
// first n items to achieve the sum
vector<vector<int> > dp(n + 1, vector<int>(2 * sum + 1, 0));
dp[0][offset] = 1;
for(int i=0; i<n; i++){
for(int j=0; j < 2 * sum + 1; j++){
// pull
if(j-nums[i] >= 0 && j+nums[i] < 2 * sum + 1)
dp[i+1][j] = dp[i][j-nums[i]] + dp[i][j+nums[i]];
else if(j-nums[i] < 0){
dp[i+1][j] = dp[i][j+nums[i]];
}
else{
dp[i+1][j] = dp[i][j-nums[i]];
}
}
}
return dp.back()[offset + S];
}
};

仔细观察Push 和 pull的区别,为什么j的值,在push里面可以从j=nums[i]; j<2*sum+1-nums[i]; j++即可。而pull中却会出问题?

一个直观的解释是,当你Push的时候,形状是正向三角的,形状类似于$\bigwedge$,范围是扩大的; 而pull的时候,形状是倒向的,$\bigvee$, 范围是缩小的,就是产生missing. 但是又有一个问题,即便是正向的,依然缺少一部分需要push到下一层的数据,比如j=0的时候,可以push到下一层的j+nums[i]. 所以为什么结果还是正确的呢?这是因为有一个offset的存在。使得之前的数据都是0,加与不加的结果一致。当然为了简化写的方法,其实也可以多循环一些,把j的范围从0到2*sum+1. 这样的好处在于,不容易遗漏,坏处在于速度稍有影响,多做几次没用的循环。

  1. 对上面的DP的优化,因为事实上只使用到了俩行的数据,因此可以使用滚动数组来减少空间复杂度。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// DP Push, 68.71%
class Solution {
public:
int findTargetSumWays(vector<int>& nums, int S) {
int n = nums.size();
int sum = accumulate(nums.begin(), nums.end(), 0);
if(sum < S) return 0;
int offset = sum;
int kMaxN = 2 * sum + 1;
// first n items to achieve the sum
vector<int> dp(kMaxN);
dp[offset] = 1;
for(int i=0; i<n; i++){
vector<int> tmp(dp.size(), 0);
for(int j = nums[i]; j < kMaxN - nums[i]; j++){
if(dp[j]){
tmp[j-nums[i]] += dp[j];
tmp[j+nums[i]] += dp[j];
}
}
swap(tmp, dp);
}
return dp[offset + S];
}
};

除了滚动数组之外,这里因为整个dp数组是sparse的,所以还可以直接用hashtable来优化空间。不需要全部存储。而且当使用hashtable的时候就不用考虑offset了,因为hashtable的key可以为负数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// DP Push, 50.86%
class Solution {
public:
int findTargetSumWays(vector<int>& nums, int S) {
int n = nums.size();
vector<unordered_map<int,int> > dp(n+1);
dp[0][0] = 1;
for(int i=0; i<n; i++){
for(const auto& item : dp[i]){
int sum = item.first, cnt = item.second;
// push
dp[i+1][sum - nums[i]] += cnt;
dp[i+1][sum + nums[i]] += cnt;
}
}
return dp.back()[S];
}
};

这个当然也可以继续优化,因为不需要n+1个hashtable,只需要2个就可以。略。

  1. 接下来需要换一个思路来解决这个问题。将原问题转换为subset sum.

    这个思路我觉得很巧妙。定义集合P为所有取正号的数的集合,T为负号的数的集合。

    set(P) $\bigcup$ set(T) = A, set(P) $\bigcap$ set(T) = $\emptyset$

    sum(P) - sum(T) = S

    sum(P) + sum(P) -sum(T) + sum(T) = S + sum(P) + sum(T)

    2*sum(P) = S + sum(A)

    sum(P) = (S + sum(A)) / 2

​ 也就是说,此时的问题已经转换为了在原集合中找到一个子集,这个子集的和满足目标与整体集合和的一半。这里的/2也可以作为剪枝条件。此时这个问题已经转变为了经典的01背包问题,一个集合,每个数或者在P或者不在P. 可能的取值只有0和1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class Solution {
public:
int findTargetSumWays(vector<int>& nums, int S) {
S = abs(S); // 这里正负对结果没有影响,完全可以在结果前面加一个负号,改变正负关系
int sum = accumulate(nums.begin(), nums.end(), 0);
if(sum < S || (S + sum) % 2) return 0;
int target = (S + sum) / 2;
vector<int> dp(target + 1, 0);
dp[0] = 1;
for(int num : nums){
for(int j = target; j >= num; j--){
dp[j] += dp[j - num]; // 因为in-place改变,所以只能用pull, 并且被改变的必须是序号比较大的,之后不会再用的
}
}
return dp[target];
}
};

这个题需要多看多想。Fight on~

Reference

http://zxi.mytechroad.com/blog/dynamic-programming/leetcode-494-target-sum/