当前位置: 代码迷 >> 综合 >> 组合游戏系列1: Leetcode中的Minimax 和 Alpha Beta剪枝
  详细解决方案

组合游戏系列1: Leetcode中的Minimax 和 Alpha Beta剪枝

热度:81   发布时间:2024-01-28 13:17:07.0

本系列,我们来看看在一种常见的组合游戏——回合制棋盘类游戏中,如何用算法来解决问题。首先,我们会介绍并解决搜索空间较小的问题,引入经典的博弈算法和相关理论,最终实现在大搜索空间中的Deep RL近似算法。在此基础上可以理解AlphaGo的原理和工作方式。
本系列的第一篇,我们介绍3个Leetcode中的零和回合制游戏,从最初的暴力解法,到动态规划最终演变成博弈论里的经典算法: minimax 以及 alpha beta 剪枝。

  • 第一篇 Leetcode中的Minimax 和 Alpha Beta剪枝

  • 第二篇: 一些组合游戏的理论

  • 第三篇: 连接N个点 的OpenAI Gym GUI环境

  • 第四篇: 蒙特卡洛树搜索(MCTS)和时间差分学习(TD learning)

Leetcode 292 Nim Game (简单)

简单题 Leetcode 292 Nim Game。

你和你的朋友,两个人一起玩 Nim游戏:桌子上有一堆石头,每次你们轮流拿掉 1 - 3 块石头。 拿掉最后一块石头的人就是获胜者。你作为先手。
你们是聪明人,每一步都是最优解。 编写一个函数,来判断你是否可以在给定石头数量的情况下赢得游戏。

示例:
输入: 4
输出: false
解释: 如果堆中有 4 块石头,那么你永远不会赢得比赛;因为无论你拿走 1 块、2 块 还是 3 块石头,最后一块石头总是会被你的朋友拿走。

定义 f ( n ) f(n) 为有 n n 个石头并采取最优策略的游戏结果, f ( n ) f(n) 的值只有可能是赢或者输。考察前几个结果: f ( 1 ) = f ( 2 ) = f ( 3 ) = W i n f(1) = f(2) = f(3) = Win ,然后来计算 f ( 4 ) f(4) 。因为玩家采取最优策略(只要有一种走法让对方必输,玩家获胜),对于4来说,玩家能走的可能是拿掉1块、2块或3块,但是无论剩余何种局面,对方都是必赢,因此,4就是必输。总的说来,递归关系如下:
f ( n ) = ? ( f ( n ? 1 ) f ( n ? 2 ) f ( n ? 3 ) ) f(n) = \neg (f(n-1) \land f(n-2) \land f(n-3))

这个递归式可以直接翻译成Python 3代码

# TLE
# Time Complexity: O(exponential)
class Solution_BruteForce:def canWinNim(self, n: int) -> bool:if n <= 3:return Truefor i in range(1, 4):if not self.canWinNim(n - i):return Truereturn False

以上的递归公式和代码很像fibonacci数的递归定义和暴力解法,因此对应的时间复杂度也是指数级的,提交代码以后会TLE。下图画出了当n=7时的递归调用,注意 5 被扩展向下重复执行了两次,4重复了4次。

292 Nim Game 暴力解法调用图 n=7

我们采用和fibonacci一样的方式来优化算法:缓存较小n的结果以此来计算较大n的结果。 Python 中,我们可以只加一行lru_cache decorator,来取得这种动态规划效果,下面的代码将复杂度降到了 O ( N ) O(N)

# RecursionError: maximum recursion depth exceeded in comparison n=1348820612
# Time Complexity: O(N)
class Solution_DP:from functools import lru_cache@lru_cache(maxsize=None)def canWinNim(self, n: int) -> bool:if n <= 3:return Truefor i in range(1, 4):if not self.canWinNim(n - i):return Truereturn False

再来画出调用图:这次5和4就不再被展开重复计算,图中绿色的节点表示缓存命中。

292 Nim Game 动归解法调用图 n=7
但还是没有AC,因为当n=1348820612时,这种方式会导致栈溢出。再改成下面的循环版本,可惜还是TLE。
# TLE for 1348820612
# Time Complexity: O(N)
class Solution:def canWinNim(self, n: int) -> bool:if n <= 3:return Truelast3, last2, last1 = True, True, Truefor i in range(4, n+1):this = not (last3 and last2 and last1)last3, last2, last1 = last2, last1, thisreturn last1

由此看来,AC 版本需要低于 O ( n ) O(n) 的算法复杂度。上面的写法似乎暗示输赢有周期性的规律。事实上,如果将输赢按照顺序画出来,就马上得出规律了:只要 n m o d ?? 4 = 0 n \mod 4 = 0 就是输,否则赢。原因如下:当面临不能被4整除的数量时 4 k + i ( i = 1 , 2 , 3 ) 4k+i (i=1,2,3) ,一方总是可以拿走 i i 个,将 4 k 4k 留给对手,而对方下轮又将返回不能被4整除的数,如此循环往复,直到这一方有1, 2, 3 个,最终获胜。

输赢分布

最终AC版本,只有一句语句。

# AC
# Time Complexity: O(1)
class Solution:def canWinNim(self, n: int) -> bool:return not (n % 4 == 0)

Leetcode 486 Predict the Winner (中等)

中等难度题目: Leetcode 486 Predict the Winner.

给定一个表示分数的非负整数数组。 玩家1从数组任意一端拿取一个分数,随后玩家2继续从剩余数组任意一端拿取分数,然后玩家1拿,……。每次一个玩家只能拿取一个分数,分数被拿取之后不再可取。直到没有剩余分数可取时游戏结束。最终获得分数总和最多的玩家获胜。
给定一个表示分数的数组,预测玩家1是否会成为赢家。你可以假设每个玩家的玩法都会使他的分数最大化。

示例 1:
输入: [1, 5, 2]
输出: False
解释: 一开始,玩家1可以从1和2中进行选择。
如果他选择2(或者1),那么玩家2可以从1(或者2)和5中进行选择。如果玩家2选择了5,那么玩家1则只剩下1(或者2)可选。
所以,玩家1的最终分数为 1 + 2 = 3,而玩家2为 5。
因此,玩家1永远不会成为赢家,返回 False。

示例 2:
输入: [1, 5, 233, 7]
输出: True
解释: 玩家1一开始选择1。然后玩家2必须从5和7中进行选择。无论玩家2选择了哪个,玩家1都可以选择233。
最终,玩家1(234分)比玩家2(12分)获得更多的分数,所以返回 True,表示玩家1可以成为赢家。

对于当前玩家,他有两种选择:左边或者右边的数。定义 maxDiff(l, r) 为剩余子数组 [ l , r ] [l,r] 时,当前玩家能取得的最大分差,那么

maxDiff ? ( l , r ) = max ? { n u m s [ l ] ? maxDiff ? ( l + 1 , r ) n u m s [ r ] ? maxDiff ? ( l , r ? 1 ) \begin{aligned}\operatorname{maxDiff}(l, r) = \max\begin{cases}nums[l] - \operatorname{maxDiff}(l + 1, r) \\nums[r] - \operatorname{maxDiff}(l, r - 1)\end{cases} \end{aligned}

对应的时间复杂度可以写出递归式,显然是指数级的:
f ( n ) = 2 f ( n ? 1 ) = O ( 2 n ) f(n) = 2f(n-1) = O(2^n)

采用暴力解法可以AC,但运算时间很长,接近TLE边缘 (6300ms)。

# AC
# Time Complexity: O(2^N)
# Slow: 6300ms
from typing import Listclass Solution:def maxDiff(self, l: int, r:int) -> int:if l == r:return self.nums[l]return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))def PredictTheWinner(self, nums: List[int]) -> bool:self.nums = numsreturn self.maxDiff(0, len(nums) - 1) >= 0

从调用图也很容易看出是指数级的复杂度

486 Predict the Winner 暴力解法调用图 n=4

上图中我们有重复计算的节点,例如[1-2]节点被计算了两次。使用 lru_cache 大法,在maxDiff 上仅加了一句,就能以复杂度 O ( n 2 ) O(n^2) 和运行时间 43ms AC。

# AC
# Time Complexity: O(N^2)
# Fast: 43ms
from functools import lru_cache
from typing import Listclass Solution:@lru_cache(maxsize=None)def maxDiff(self, l: int, r:int) -> int:if l == r:return self.nums[l]return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))def PredictTheWinner(self, nums: List[int]) -> bool:self.nums = numsreturn self.maxDiff(0, len(nums) - 1) >= 0

动态规划解法调用图可以看出节点 [1-2] 这次没有被计算两次。

486 Predict the Winner 动归解法调用图 n=4

Leetcode 464 Can I Win (中等)

类似但稍有难度的题目 Leetcode 464 Can I Win。难点在于使用了位的状态压缩。

在 “100 game” 这个游戏中,两名玩家轮流选择从 1 到 10 的任意整数,累计整数和,先使得累计整数和达到 100 的玩家,即为胜者。
如果我们将游戏规则改为 “玩家不能重复使用整数” 呢?
例如,两个玩家可以轮流从公共整数池中抽取从 1 到 15 的整数(不放回),直到累计整数和 >= 100。
给定一个整数 maxChoosableInteger (整数池中可选择的最大数)和另一个整数 desiredTotal(累计和),判断先出手的玩家是否能稳赢(假设两位玩家游戏时都表现最佳)?
你可以假设 maxChoosableInteger 不会大于 20, desiredTotal 不会大于 300。

示例:
输入:
maxChoosableInteger = 10
desiredTotal = 11
输出:
false
解释:
无论第一个玩家选择哪个整数,他都会失败。
第一个玩家可以选择从 1 到 10 的整数。
如果第一个玩家选择 1,那么第二个玩家只能选择从 2 到 10 的整数。
第二个玩家可以通过选择整数 10(那么累积和为 11 >= desiredTotal),从而取得胜利.
同样地,第一个玩家选择任意其他整数,第二个玩家都会赢。

# AC
# Time Complexity: O:(2^m*m), m: maxChoosableInteger
class Solution:from functools import lru_cache@lru_cache(maxsize=None)def recurse(self, status: int, currentTotal: int) -> bool:for i in range(1, self.maxChoosableInteger + 1):if not (status >> i & 1):new_status = 1 << i | statusif currentTotal + i >= self.desiredTotal:return Trueif not self.recurse(new_status, currentTotal + i):return Truereturn Falsedef canIWin(self, maxChoosableInteger: int, desiredTotal: int) -> bool:self.maxChoosableInteger = maxChoosableIntegerself.desiredTotal = desiredTotalsum = maxChoosableInteger * (maxChoosableInteger + 1) / 2if sum < desiredTotal:return Falsereturn self.recurse(0, 0)

上面的代码算法复杂度为 O ( m 2 m ) O(m 2^m) ,m是maxChoosableInteger。由于所有状态的数量是 2 m 2^m ,对于每个状态,最多会尝试 m m 走法。

Minimax 算法

至此,我们AC了leetcode中的几道零和回合制博弈游戏。事实上,在这个领域有通用的算法:回合制博弈下的minimax。算法背景如下,两个玩家轮流玩,第一个玩家max的目的是将游戏的效用最大化,第二个玩家min则是最小化效用。比如,下面的节点表示玩家选取节点后游戏的效用,当两个玩家都能采取最优策略,Minimax 算法从底层节点来计算,游戏的结果是最终max 玩家会得到-7。

Wikipedia Minimax 例子

Minimax Python 3伪代码如下。

def minimax(node: Node, depth: int, maximizingPlayer: bool) -> int:if depth == 0 or is_terminal(node):return evaluate_terminal(node)if maximizingPlayer:value:int = ?∞for child in node:value = max(value, minimax(child, depth ? 1, False))return valueelse: # minimizing playervalue := +for child in node:value = min(value, minimax(child, depth ? 1, True))return value

Minimax: 486 Predict the Winner

我们知道486 Predict the Winner 是有minimax解法的,但如何具体实现,其难点在于如何定义合适的游戏价值或者效用。之前的解法中,我们定义maxDiff(l, r) 来表示当前玩家面临子区间 [ l , r ] [l, r] 时能取得的最大分差。对于minimax算法,max 玩家要最大化游戏价值,min玩家要最小化游戏价值。先考虑最简单情况即只有一个数x时,若定义max玩家在此局面下得到这个数时游戏价值为 +x,则min玩家为-x,即max玩家得到的所有数为正( + a 1 + a 2 + . . . = A +a_1 + a_2 + ... = A ),min玩家得到的所有数为负( ? b 1 ? b 2 ? . . . = ? B -b_1 - b_2 - ... = -B )。至此,max玩家的目标就是 m a x ( A ? B ) max(A-B) ,min玩家是 m i n ( A ? B ) min(A-B) 。有了精确的定义和优化目标,代码只需要套一下上面的模版。

# AC
from functools import lru_cache
from typing import Listclass Solution:# max_player: max(A - B)# min_player: min(A - B)@lru_cache(maxsize=None)def minimax(self, l: int, r: int, isMaxPlayer: bool) -> int:if l == r:return self.nums[l] * (1 if isMaxPlayer else -1)if isMaxPlayer:return max(self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))else:return min(-self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),-self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))def PredictTheWinner(self, nums: List[int]) -> bool:self.nums = numsv = self.minimax(0, len(nums) - 1, True)return v >= 0
Minimax 486 调用图 nums=[1, 5, 2, 4]

Minimax: 464 Can I Win

该题目是很典型的此类游戏,即结果为赢输平,但是中间的状态没有直接对应的游戏价值。对于这样的问题,一般定义为,max玩家胜,价值 +1,min玩家胜,价值-1,平则0。下面的AC代码实现了 Minimax 算法。算法中针对两个玩家都有剪枝(没有剪枝无法AC)。具体来说,max玩家一旦在某一节点取得胜利(value=1),就停止继续向下搜索,因为这是他能取得的最好分数。同理,min玩家一旦取得-1也直接返回上层节点。这个剪枝可以泛化成 alpha beta剪枝算法。

# AC
class Solution:from functools import lru_cache@lru_cache(maxsize=None)# currentTotal < desiredTotaldef minimax(self, status: int, currentTotal: int, isMaxPlayer: bool) -> int:import mathif status == self.allUsed:return 0  # draw: no winnerif isMaxPlayer:value = -math.inffor i in range(1, self.maxChoosableInteger + 1):if not (status >> i & 1):new_status = 1 << i | statusif currentTotal + i >= self.desiredTotal:return 1  # shortcutvalue = max(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))if value == 1:return 1return valueelse:value = math.inffor i in range(1, self.maxChoosableInteger + 1):if not (status >> i & 1):new_status = 1 << i | statusif currentTotal + i >= self.desiredTotal:return -1  # shortcutvalue = min(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))if value == -1:return -1return value

Alpha-Beta 剪枝

在464 Can I Win minimax 算法代码实现中,我们发现有剪枝优化空间。对于每个节点,定义两个值alpha 和 beta,表示从根节点到目前局面时,max玩家保证能取得的最小值以及min玩家能保证取得的最大值。初始时,根节点alpha = ?∞ , beta = +∞,表示游戏最终的价值在区间 [?∞, +∞]中。在向下遍历的过程中,子节点先继承父节点的 alpha beta 值进而继承区间 [alpha, beta]。当子节点在向下遍历的时候同步更新alpha 或者 beta,一旦区间[alpha, beta]非法就立即向上返回。举个Wikimedia的例子来进一步说明:

  1. 根节点初始时: alpha = ?∞, beta = +∞

  2. 根节点,最左边子节点返回4后: alpha = 4, beta = +∞

  3. 根节点,中间子节点返回5后: alpha = 5, beta = +∞

  4. 最右Min节点(标1节点),初始时: alpha = 5, beta = +∞

  5. 最右Min节点(标1节点),第一个子节点返回1后: alpha = 5, beta = 1

此时,最右Min节点的alpha, beta形成了无效区间[5, 1],满足了剪枝条件,因此可以不用计算它的第二个和第三个子节点。如果剩余子节点返回值 > 1,比如2,由于这是个min节点,将会被已经到手的1替换。若其他子节点返回值 < 1,但由于min的父节点有效区间是[5, +∞],已经保证了>=5,小于5的值也会被忽略。

Wikimedia Alpha Beta 剪枝例子
Minimax Python 3伪代码如下
def alpha_beta(node: Node, depth: int, α: int, β: int, maximizingPlayer: bool) -> int:if depth == 0 or is_terminal(node):return evaluate_terminal(node)if maximizingPlayer:value: int = ?∞for child in node:value = max(value, alphabeta(child, depth ? 1, α, β, False))α = max(α, value)if α >= β:break # β cut-offreturn valueelse:value: int = +for child in node:value = min(value, alphabeta(child, depth ? 1, α, β, True))β = min(β, value)if β <= α:break # α cut-offreturn value

Alpha-Beta Pruning: 486 Predict the Winner

# AC
import math
from functools import lru_cache
from typing import Listclass Solution:def alpha_beta(self, l: int, r: int, curr: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:if l == r:return curr + self.nums[l] * (1 if isMaxPlayer else -1)if isMaxPlayer:ret = self.alpha_beta(l + 1, r, curr + self.nums[l], not isMaxPlayer, alpha, beta)alpha = max(alpha, ret)if alpha >= beta:return alpharet = max(ret, self.alpha_beta(l, r - 1, curr + self.nums[r], not isMaxPlayer, alpha, beta))return retelse:ret = self.alpha_beta(l + 1, r, curr - self.nums[l], not isMaxPlayer, alpha, beta)beta = min(beta, ret)if alpha >= beta:return betaret = min(ret, self.alpha_beta(l, r - 1, curr - self.nums[r], not isMaxPlayer, alpha, beta))return retdef PredictTheWinner(self, nums: List[int]) -> bool:self.nums = numsv = self.alpha_beta(0, len(nums) - 1, 0, True, -math.inf, math.inf)return v >= 0

Alpha-Beta Pruning: 464 Can I Win

# AC
class Solution:from functools import lru_cache@lru_cache(maxsize=None)# currentTotal < desiredTotaldef alpha_beta(self, status: int, currentTotal: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:import mathif status == self.allUsed:return 0  # draw: no winnerif isMaxPlayer:value = -math.inffor i in range(1, self.maxChoosableInteger + 1):if not (status >> i & 1):new_status = 1 << i | statusif currentTotal + i >= self.desiredTotal:return 1  # shortcutvalue = max(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))alpha = max(alpha, value)if alpha >= beta:return valuereturn valueelse:value = math.inffor i in range(1, self.maxChoosableInteger + 1):if not (status >> i & 1):new_status = 1 << i | statusif currentTotal + i >= self.desiredTotal:return -1  # shortcutvalue = min(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))beta = min(beta, value)if alpha >= beta:return valuereturn value

C++, Java, Javascript AC 486 Predict the Winner

最后介绍一种不同的DP实现:用C++, Java, Javascript 实现自底向上的DP解法来AC leetcode 486,当然其他语言没有Python的lru_cache大法。以下实现中,注意DP解的构建顺序,先解决小规模的问题,并在此基础上计算稍大的问题。值得一提的是,以下的循环写法严格保证了 n 2 n^2 次循环,但是自顶向下的计划递归可能会少于 n 2 n^2 次循环。

Java AC Code

// AC
class Solution {public boolean PredictTheWinner(int[] nums) {int n = nums.length;int[][] dp = new int[n][n];for (int i = 0; i < n; i++) {dp[i][i] = nums[i];}for (int l = n - 1; l >= 0; l--) {for (int r = l + 1; r < n; r++) {dp[l][r] = Math.max(nums[l] - dp[l + 1][r],nums[r] - dp[l][r - 1]);}}return dp[0][n - 1] >= 0;}
}

C++ AC Code

// AC
class Solution {
public:bool PredictTheWinner(vector<int>& nums) {int n = nums.size();vector<vector<int>> dp(n, vector<int>(n, 0));for (int i = 0; i < n; i++) {dp[i][i] = nums[i];}for (int l = n - 1; l >= 0; l--) {for (int r = l + 1; r < n; r++) {dp[l][r] = max(nums[l] - dp[l + 1][r], nums[r] - dp[l][r - 1]);}}return dp[0][n - 1] >= 0;}
};

Javascript AC Code

/*** @param {number[]} nums* @return {boolean}*/
var PredictTheWinner = function(nums) {const n = nums.length;const dp = new Array(n).fill().map(() => new Array(n));for (let i = 0; i < n; i++) {dp[i][i] = nums[i];}for (let l = n - 1; l >=0; l--) {for (let r = i + 1; r < n; r++) {dp[l][r] = Math.max(nums[l] - dp[l + 1][r],nums[r] - dp[l][r - 1]);}}return dp[0][n-1] >=0;
};

著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
同步发表在微信公众号,欢迎关注。
在这里插入图片描述

  相关解决方案