跳转至

搜索 Sudoku

前置知识

深度优先搜索

目标

Sudoku系列题目,剪枝

第一题,POJ2676Sudoku

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

const int N = 15;

bool row[N][N], col[N][N], grid[N][N];
int g[N][N], T;
char s[N][N];

void init() {
    memset(row, false, sizeof row);
    memset(col, false, sizeof col);
    memset(grid, false, sizeof grid);
    memset(g, 0, sizeof g);

    for (int i = 0; i < 9; i++) scanf("%s", s[i]);
    for (int i = 0; i < 9; i++)
        for (int j = 0; j < 9; j++) {
            g[i][j] = s[i][j] - '0';

            if (g[i][j]){
                int t = i / 3 * 3 + j / 3;
                row[i][g[i][j]] = col[j][g[i][j]] = grid[t][g[i][j]] = true;
            }
        }
}

// 0-index, {0...8}, the end is 9
bool dfs(int i, int j) {
    if (i == 9) return true;

    bool flag = false;
    if (g[i][j]) {
        if (j == 8) flag = dfs(i + 1, 0);
        else flag = dfs(i, j + 1);

        return flag;
    }

    int t = i / 3 * 3 + j / 3;
    for (int num = 1; num <= 9; num ++) {
        if (row[i][num] || col[j][num] || grid[t][num]) continue;

        row[i][num] = col[j][num] = grid[t][num] = true;
        g[i][j] = num;

        if (j == 8) flag = dfs(i + 1, 0);
        else flag = dfs(i, j + 1);

        if (flag) return true;

        g[i][j] = 0;
        row[i][num] = col[j][num] = grid[t][num] = false;
    } 

    return false;
}

void print() {
    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9; j++) cout << g[i][j];
        puts("");
    }
}

int main() {
    cin >> T;
    while (T--) {
        init();
        dfs(0, 0);
        print();
    }

    return 0;
}

第二题,POJ3074Sudoku

// Time Limit Exceeded
// using bianry system to record which number can be used
// but in the recursion tree, there are too many choises, so, TLE
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>

using namespace std;

const int N = 15, M = 1 << 9 + 10;

int row[N], col[N], grid[N];
int g[N][N];
int num[M];
string s;

int lowbit(int x) {
    return x & -x;
}

void f(int x) {
    for (int i = 8; i >= 0; i--)
        if (x >> i & 1) cout << 1;
        else cout << 0;
    puts("");
}

// 0-index, {0...8}, the end is 9
bool dfs(int i, int j) {
    if (i == 9) return true;

    bool flag = false;
    if (g[i][j]) {
        if (j == 8) flag = dfs(i + 1, 0);
        else flag = dfs(i, j + 1);

        return flag;
    }

    int t = i / 3 * 3 + j / 3;
    int status = row[i] & col[j] & grid[t];

    for (int mask = status; mask > 0; mask -= lowbit(mask)) {
        g[i][j] = num[lowbit(mask)];
        row[i] ^= lowbit(mask);
        col[j] ^= lowbit(mask);
        grid[t] ^= lowbit(mask);

        if (j == 8) flag = dfs(i + 1, 0);
        else flag = dfs(i, j + 1);

        if (flag) return true;

        grid[t] ^= lowbit(mask);
        col[j] ^= lowbit(mask);
        row[i] ^= lowbit(mask);
        g[i][j] = 0;
    }

    return false;
}

void print() {
    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9; j++) cout << g[i][j];
    }
    puts("");
}

void init() {
    memset(g, 0, sizeof g);
    for (int i = 0; i < 9; i++) num[1 << i] = i + 1; // pretreatment, 000000001, write '1' on the ceil 
    for (int i = 0; i < 9; i++) row[i] = col[i] = grid[i] = (1 << 9) - 1; // 111111111, represent nine ceils can write

    int idx = 0;
    for (int i = 0; i < 9; i++)
        for (int j = 0; j < 9; j++) {
            if (s[idx] != '.') g[i][j] = s[idx] - '0';
            idx++;
        }

    for (int i = 0; i < 9; i++)
        for (int j = 0; j < 9; j++) {
            if (g[i][j]) {
                row[i] ^= (1 << (g[i][j] - 1));
                col[j] ^= (1 << (g[i][j] - 1));
                int t = i / 3 * 3 + j / 3;
                grid[t] ^= (1 << (g[i][j] - 1));
            }
        }
}

int main() {
    while (cin >> s) {
        if (s == "end") break;
        init();
        dfs(0, 0);
        print();
    }

    return 0;
}
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>

using namespace std;

const int N = 15, M = 1 << 9 + 10;

int row[N], col[N], grid[N], r, c, g;
string s;
int f[M];
int cnt[M]; // the binary number have how many '1'
int X[81], Y[81]; // 0~80 correspond to which row, which col 

int get_cnt(int x) {
    int tot = 0;
    while (x) {
        x -= (x & -x);
        tot++;
    }

    return tot;
}

void get(int i) {
    r = X[i], c = Y[i], g = r / 3 * 3 + c / 3;
}

void init() {
    for (int i = 0; i < (1 << 9); i++) cnt[i] = get_cnt(i);
    for (int i = 0; i < 9; i++) f[1 << i] = i;
    for (int i = 0; i < 81; i++) X[i] = i / 9, Y[i] = i % 9;
}

void press(int i, int k) {
    get(i);
    row[r] ^= (1 << k);
    col[c] ^= (1 << k);
    grid[g] ^= (1 << k);
}

bool dfs(int u) {
    if (u == 0) return true;

    // find the point has the minimal chioses
    int t, minn = 10;
    for (int i = 0; i < 81; i++) 
        if (s[i] == '.') {
            get(i);
            int w = row[r] & col[c] & grid[g];
            if (cnt[w] < minn) {
                minn = cnt[w];
                t = i;
            }
        }

    get(t);
    int w = row[r] & col[c] & grid[g]; // w is this point has how many choises of number can be used

    // solve subproblem
    while (w) {
        int now = f[w & -w]; // press which this number(using lowbit() )
        s[t] = '1' + now;
        press(t, now);

        if (dfs(u - 1)) return true;

        // backtracking
        press(t, now);
        s[t] = '.';

        w -= w & -w;
    }

    return false;
}

void Sudoku() {
    for (int i = 0; i < 9; i++) row[i] = col[i] = grid[i] = (1 << 9) - 1;

    int cnt = 0;
    for (int i = 0; i < 81; i++) 
        if (s[i] != '.') press(i, s[i] - '1');
        else cnt++;

    // the string has cnt point to press number
    dfs(cnt);
    cout << s << '\n';
}

int main() {
    init();

    while (cin >> s) {
        if (s == "end") break;
        Sudoku(); 
    }

    return 0;
}

第三题,POJ3076Sudoku

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>

using namespace std;

const int N = 20;

int mp[N][N]; // 记录数独的地图
int vis[N][N]; // 用二进制记录(i,j)这个位置可以有哪些数字可以填,0可填 1不可填
int cnt; // 记录一共填写了多少个点,到达256时代表完成填写
char s[N][N]; // 存储输入

void init() {
    cnt = 0;
    memset(vis, 0, sizeof vis);
    memset(mp, 0, sizeof mp);
}

// 在(x, y)位置填写 k
void fill(int x, int y, int k) {
    cnt++;
    mp[x][y] = k;

    for (int i = 0; i < 16; i++) {
        vis[x][i] |= 1 << (k - 1); // 标记这一行
        vis[i][y] |= 1 << (k - 1); // 标记这一列
    }

    // 所在九宫格的左上角的位置(gx, gy)
    // 标记这个九宫格的 9 个位置,k这个数字被使用了
    int gx = x / 4 * 4, gy = y / 4 * 4;
    for (int i = gx; i <= gx + 3; i++)
        for (int j = gy; j <= gy + 3; j++)
            vis[i][j] |= 1 << (k - 1);
}

int get_only_zero(int x) {
//  for(int i=0; x; i++) {
//      if(x&1) {
//          if(x>>1==0)return i;
//          return -1;//多个1则不满足剪枝条件
//      }
//      x>>=1;
//  }
//  return -1;


    int t = -1, cnt = 0;
    for (int i = 0; i < 16; i++) {
        if (((x >> i) & 1) == 0) {
            t = i, cnt++;
        }
    }

    if (cnt > 1) return -1; // 有多个位置是0
    if (cnt == 0) return -1; // 没有可以填充的
    return t; // 只有一个0,可以填充,返回位置
}

// 判断第x行,是否k是唯一可以填写的,如果是立即填上
// 返回-1表示k已经出现过,返回-2表示没有(k+1)填的位置
int hang(int x, int k) {
    int p = -1;
    for (int y = 0; y < 16; y++) {
        if (mp[x][y] == k + 1) return -1;
        if (mp[x][y] > 0) continue;
        if ((vis[x][y] & (1 << k)) == 0) {
            if (p != -1) return -1; // 不只一个位置可以填写,剪枝
            p = y;
        }
    }
    if (p != -1) return p; // 只有一个位置可以填写,立即填写上
    return -2;
}

int lie(int y, int k) {
    int p = -1;
    for (int x = 0; x < 16; x++) {
        if (mp[x][y] == k + 1) return -1;
        if (mp[x][y] > 0) continue;
        if ((vis[x][y] & (1 << k)) == 0) {
            if (p != -1) return -1; // 不只一个位置可以填写,剪枝
            p = x;
        }
    }
    if (p != -1) return p; // 只有一个位置可以填写,立即填写上
    return -2;
}

void gong(int gx, int gy, int k, int &x, int &y) {
    x = -2;

    for (int i = gx; i <= gx + 3; i++)
        for (int j = gy; j <= gy + 3; j++) {
            if (mp[i][j] == k + 1) {
                x = -1;
                return ;
            }
            if (mp[i][j] > 0) continue;
            if ((vis[i][j] & (1 << k)) == 0) {
                if (x != -2) {
                    x = -1;
                    return ;
                }

                x = i, y = j;
            }
        }
}

// 返回x中有多少个1
int count_one(int x) {
    int cnt = 0;
    while (x) {
        if (x & 1) cnt++;
        x >>= 1;
    }

    return cnt;
}

bool dfs() {
    if (cnt == 256) return true;

    // 遍历当前所有的空格
    // 如果某个位置只有1个字母可填,那么立即填上
    for (int i = 0; i < 16; i++)
        for (int j = 0; j < 16; j++)
            if (!mp[i][j]) {
                int k = get_only_zero(vis[i][j]);
                if (k != -1) fill(i, j, k + 1);
            }

    // 如果某行,只有1个字母可以填,立即填上
    for (int i = 0; i < 16; i++)
        for (int k = 0; k < 16; k++) {
            int t = hang(i, k);
            if (t == -2) return false;
            if (t != -1) fill(i, t, k + 1);
        }

    for (int j = 0; j < 16; j++)
        for (int k = 0; k < 16; k++) {
            int t = lie(j, k);
            if (t == -2) return false;
            if (t != -1) fill(t, j, k + 1);
        }

    for (int i = 0; i < 16; i += 4)
        for (int j = 0; j < 16; j += 4)
            for (int k = 0; k < 16; k++) {
                int x, y;
                gong(i, j, k, x, y);
                if (x == -2) return false;
                if (x != -1) fill(x, y, k + 1);
            }

    if (cnt == 256) return true;

    // 选择可填的字母最少得位置,枚举填写哪个字母作为分支
    // 把辅助数组的副本记录在局部变量里
    int t_cnt = cnt;
    int t_mp[N][N];
    int t_vis[N][N];

    int maxn = -1, mx, my;
    for (int i = 0; i < 16; i++)
        for (int j = 0; j < 16; j++) {
            t_mp[i][j] = mp[i][j];
            t_vis[i][j] = vis[i][j];

            if (mp[i][j] > 0) continue;

            int t = count_one(vis[i][j]);
            if (t > maxn) {
                maxn = t;
                mx = i, my = j;
            }
        }

    // (mx, my)这个位置上的1最多,那么就是0最少,就是可以填写的字母最少
    // 找到了这个位置后,枚举可以填写的字母
    for (int k = 0; k < 16; k++)
        if ((vis[mx][my] & (1 << k)) == 0) {
            fill(mx, my, k + 1);

            if (dfs()) return true;

            // 回溯,恢复现场
            cnt = t_cnt;
            for (int i = 0; i < 16; i++)
                for (int j = 0; j < 16; j++)
                    mp[i][j] = t_mp[i][j], vis[i][j] = t_vis[i][j];
        }

    return false;
}

void print() {
    for (int i = 0; i < 16; i++) {
        for (int j = 0; j < 16; j++)
            printf("%c", 'A' + mp[i][j] - 1);
        puts("");
    }

    puts("");
}

int main() {
    while (true) {
        init();

        for (int i = 0; i < 16; i++) {
            //if (!(cin >> s[i])) return 0; // 这样写不行 Compile Error
            //if (scanf("%s", s[i]) == EOF) return 0; // 这样写可以 
            //if (scanf("%s", s[i]) == 0) return 0;  // 这样写不行 Output Limit Exceeded
            if (scanf("%s", s[i]) != 1) return 0; // 这样写可以

            for (int j = 0; j < 16; j++)
                if (s[i][j] != '-') fill(i, j, s[i][j] - 'A' + 1);
        }

        dfs();
        print();
    }

    return 0;
}

总结

第一题,POJ2676Sudoku。bool dfs(int i, int j),代码框架是求解子问题,尝试9个数字,能不能填,能填就填

第二题,POJ3074Sudoku。

  • 第一种方法,用2676的方法+常数优化,就会超时了。用位运算,来记录行、列、小九宫格,能放什么数字,依然是用求解子问题的访问处理(用位运算来代替数组执行“对数独各个位置所填数字的记录”以及“可填性的检查与统计”,就是所谓的“常数优化”)。
  • 第二种方法,先找到这81个位置,哪一个位置的可填数字最少,就优先处理这个位置。这个位置上有几个数字可以填,再进行对应的求解子问题

第三题,POJ3076Sudoku。还需要加入更多的剪枝。对数独进行更加全面的可行性判定,尽早发现无解的分支进行回溯。我们加入以下的可行性剪枝:

  • 1、遍历当前所有的空格
  • (1)若某个位置A~P都不能填,立即回溯
  • (2)若某个位置只有1个字母可填,立即填上这个字母
  • 2、考虑所有行
  • (1)若某个字母不能填在该行的任何一个空位上,立即回溯
  • (2)若某个字母只能填在该行的某个空位上,立即填写
  • 3、考虑所有的列,执行相同操作
  • 4、考虑所有的十六宫格,执行相同操作
  • 5、选择可填的字母最少得位置,枚举填写哪个字母作为分支

参考

《进阶指南》