MATLAB – Sudoku Solver

Update: This was my first crack at this problem, and while it works for the two given puzzles, it breaks on other puzzles. Please see Revision 2.0 for an updated version of the code.

I was traveling last week and flew from the east coast to the west coast, non-stop. It’s a pretty long flight, about 5 hours. To amuse myself, I started doing the sudoku puzzle from the in-flight magazine. As I was finishing up the easy puzzle, I got this crazy idea to write my own sudoku solver. “It can’t possibly be difficult, right?” I thought to myself. After all, solving a sudoku puzzle is just applied set theory: finding subsets and narrowing down the possibilities.

So once they announced that we were at a high enough altitude to use portable electronics, I whipped out my trusty laptop and went to work. Here’s the result.

main

This is just a main file that you can use to test out the program. Puzzle 1 is the warm-up puzzle from the back of the magazine. Puzzle 2 is a slightly harder puzzle I found on the internet after I landed. The answer matrices are there so you can compare them against the solver's output.

clear
clc

puzzle_1 = ...
    [0 3 0  0 0 0  0 1 0; ...
     5 0 0  0 6 0  0 9 2; ...
     6 0 2  9 0 5  8 0 0; ...
     ...
     0 9 3  7 0 0  0 2 5; ...
     2 0 0  4 0 9  0 0 8; ...
     4 8 0  0 0 3  1 6 0; ...
     ...
     0 0 1  6 0 4  2 0 3; ...
     3 7 0  0 2 0  0 0 4; ...
     0 2 0  0 0 0  0 5 1];

answer_1 = ...
    [7 3 9  8 4 2  5 1 6; ...
     5 4 8  1 6 7  3 9 2; ...
     6 1 2  9 3 5  8 4 7; ...
     ...
     1 9 3  7 8 6  4 2 5; ...
     2 6 5  4 1 9  7 3 8; ...
     4 8 7  2 5 3  1 6 9; ...
     ...
     8 5 1  6 9 4  2 7 3; ...
     3 7 6  5 2 1  9 8 4; ...
     9 2 4  3 7 8  6 5 1];

[soln_1 iters_1] = sudokuSolver(puzzle_1);

puzzle_2 = ...
    [0 0 5  4 2 9  3 0 0; ...
     0 7 0  0 0 0  0 6 0; ...
     0 9 8  0 0 0  2 4 0; ...
     ...
     0 0 0  2 3 1  0 0 0; ...
     0 0 0  0 0 0  0 0 0; ...
     0 0 0  7 6 5  0 0 0; ...
     ...
     0 6 7  0 0 0  9 3 0; ...
     0 5 0  0 0 0  0 2 0; ...
     0 0 9  6 1 8  7 0 0];

answer_2 = ...
    [6 1 5  4 2 9  3 8 7; ...
     4 7 2  8 5 3  1 6 9; ...
     3 9 8  1 7 6  2 4 5; ...
     ...
     7 4 6  2 3 1  5 9 8; ...
     5 2 1  9 8 4  6 7 3; ...
     9 8 3  7 6 5  4 1 2; ...
     ...
     8 6 7  5 4 2  9 3 1; ...
     1 5 4  3 9 7  8 2 6; ...
     2 3 9  6 1 8  7 5 4];
 
[soln_2 iters_2] = sudokuSolver(puzzle_2);
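
The main file defines the answer matrices but never checks anything against them. Here is a minimal sketch of a validity check for a completed grid. It is not part of the original program: board and isPerm are illustrative names, and answer_1 stands in for the solver output so the snippet runs on its own.

```matlab
% answer_1 from the main file, used here as the grid to check.
board = [7 3 9  8 4 2  5 1 6; ...
         5 4 8  1 6 7  3 9 2; ...
         6 1 2  9 3 5  8 4 7; ...
         1 9 3  7 8 6  4 2 5; ...
         2 6 5  4 1 9  7 3 8; ...
         4 8 7  2 5 3  1 6 9; ...
         8 5 1  6 9 4  2 7 3; ...
         3 7 6  5 2 1  9 8 4; ...
         9 2 4  3 7 8  6 5 1];

% Hypothetical helper: true when v holds each of 1..9 exactly once.
isPerm = @(v) isequal(sort(v(:)).', 1:9);

valid = true;
for k = 1:9
    rows = 3*floor((k-1)/3) + (1:3);   % rows of square k
    cols = 3*mod(k-1, 3) + (1:3);      % columns of square k
    valid = valid && isPerm(board(k, :)) ...      % row k
                  && isPerm(board(:, k)) ...      % column k
                  && isPerm(board(rows, cols));   % square k
end

disp(valid)
```

With real solver output you could also compare directly, for example isequal(soln_1, answer_1).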

sudokuSolver

This next hunk of code is really just a wrapper for the actual code. It calls solveStep repeatedly until a pass makes no change to the puzzle, then returns the puzzle matrix along with the number of passes it took.

function [soln n] = sudokuSolver(A)

% Input: 9x9 matrix, unknowns should be populated with 0.

n = 0;

puzzle = num(A);

done = 0;
while ~done
    [puzzle changed] = solveStep(puzzle);
    
    if ~changed
        done = 1;
    end
    
    n = n + 1;
end

soln = puzzle.puzz;

end
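
Note that the loop stops when a pass changes nothing, which is not the same thing as the puzzle being solved. A caller might want to check for leftover zeros. The sketch below is an illustration only: soln_demo is a hypothetical result (answer_2 with two cells blanked), not real solver output.

```matlab
% Hypothetical solver result with two cells still unknown (0).
soln_demo = [6 1 5  4 2 9  3 8 7; ...
             4 7 2  8 0 3  1 6 9; ...
             3 9 8  1 7 6  2 4 5; ...
             7 4 6  2 3 1  5 9 8; ...
             5 2 1  9 8 4  6 7 3; ...
             9 8 3  7 6 5  4 1 2; ...
             0 6 7  5 4 2  9 3 1; ...
             1 5 4  3 9 7  8 2 6; ...
             2 3 9  6 1 8  7 5 4];

% Locate the cells the solver never filled in.
[stuck_rows, stuck_cols] = find(soln_demo == 0);
num_unsolved = numel(stuck_rows);

if num_unsolved > 0
    fprintf('Solver stalled with %d unsolved cells.\n', num_unsolved);
end
```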

num

This final piece of code is the workhorse. Basically, I decided to organize things using MATLAB’s object-oriented structure. A num object has only two properties: puzz, the puzzle, and poss, the possibilities. puzz is a 9x9 matrix which stores the values of the puzzle, and poss is a 9x9 cell array which stores the values still possible at each location. Each time you run solveStep, both are updated and you end up one step closer to the solution.
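
To make the poss idea concrete, here is a minimal sketch (separate from the class below) of how the possibilities for a single cell fall out of its row, column, and square. It uses puzzle_1 from the main file; the names box, used, and cands are just for illustration.

```matlab
puzzle_1 = [0 3 0  0 0 0  0 1 0; ...
            5 0 0  0 6 0  0 9 2; ...
            6 0 2  9 0 5  8 0 0; ...
            0 9 3  7 0 0  0 2 5; ...
            2 0 0  4 0 9  0 0 8; ...
            4 8 0  0 0 3  1 6 0; ...
            0 0 1  6 0 4  2 0 3; ...
            3 7 0  0 2 0  0 0 4; ...
            0 2 0  0 0 0  0 5 1];

ii = 1; jj = 1;                       % the cell to inspect (top-left)
rows = 3*floor((ii-1)/3) + (1:3);     % rows of the enclosing square
cols = 3*floor((jj-1)/3) + (1:3);     % columns of the enclosing square
box  = puzzle_1(rows, cols);

% Every value already placed in the same row, column, or square.
% The zeros (unknowns) are harmless because 0 is not in 1:9.
used = [puzzle_1(ii, :), puzzle_1(:, jj).', box(:).'];

% Whatever is left of 1..9 is still possible for this cell.
cands = setdiff(1:9, used);
disp(cands)
```

For the top-left cell this leaves 7, 8, and 9, which agrees with the 7 in answer_1.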

classdef num
    properties
        puzz
        poss
    end
    
    methods
        
        % Constructor
        % 
        %   Creates a num object which holds the puzzle matrix and the
        %   possibilities cell array. The puzzle matrix is simply a 9x9
        %   matrix with 0s for the unknown locations. The possibilities
        %   cell array is 9x9, and each cell holds a subset of the
        %   integers 1 through 9.
        function obj = num(puzz)
            % Assign the given puzzle matrix
            obj.puzz = puzz;
            
            % Assign all possibilities
            all_poss = [1 2 3 4 5 6 7 8 9];
            obj.poss = cell(9, 9);
            for ii = 1:9
                for jj = 1:9
                    obj.poss{ii, jj} = all_poss;
                end
            end
            
            % Update the possibilities cell array
            for ii = 1:9
                obj = updatePoss(obj, @solvedUpdate, ii);
            end
        end
        
        % Solve Step
        %
        %   After the object has been initialized, this function is called
        %   repeatedly to check the possibilities cell array. If only one
        %   possibility exists in a cell, then that possibility is written
        %   into the puzzle matrix. After writing the puzzle matrix, the
        %   specific row, column, and square of the possibilities cell
        %   array are updated and the function returns. If no cell could
        %   be filled in, subsetUpdate is run on every row, column, and
        %   square instead.
        function [obj changed] = solveStep(obj)
            changed = 0;
            
            for ii = 1:9
                for jj = 1:9
                    if isempty(obj.poss{ii, jj})
                        continue
                    end
                    
                    obj.poss{ii, jj} = checkOnlyPoss(obj, ii, jj);
                    
                    if numel(obj.poss{ii, jj}) == 1
                        obj.puzz(ii, jj) = obj.poss{ii, jj};
                        obj = updatePoss(obj, @solvedUpdate, ii, jj);
                        changed = 1;
                        return
                    end
                end
            end
            
            for ii = 1:9
                obj = updatePoss(obj, @subsetUpdate, ii);
            end
        end
        
        % Check Only Possible
        %
        %   This function checks the possibilities cell array for values
        %   that can go in only one place. Meaning, if the number 5 can
        %   only be placed in one specific cell of a row, column, or
        %   square, then the possibilities for that cell are replaced by
        %   that single value. Immediately after this function returns,
        %   isolated possibilities are written into the puzzle matrix.
        function poss_out = checkOnlyPoss(obj, ii, jj)
            poss_out = obj.poss{ii, jj};
            
            if isempty(obj.poss{ii, jj})
                return
            end
            
            row_poss = {obj.poss{ii, :}};
            col_poss = {obj.poss{:, jj}};
            
            [sqr, sqr_num, ind_rows, ind_cols] = getSqr(obj, ii, jj);
            sqr_poss = {obj.poss{ind_rows, ind_cols}};
            
            vals = obj.poss{ii, jj};
            for mm = 1:numel(vals)
                val = vals(mm);
                
                found_row = 0;
                found_col = 0;
                found_sqr = 0;
                
                for nn = 1:9
                    % Check rows
                    if nn ~= jj && ~found_row
                        if ~isempty(intersect(row_poss{nn}, val))
                            found_row = 1;
                        end
                    end
                    
                    % Check cols
                    if nn ~= ii && ~found_col
                        if ~isempty(intersect(col_poss{nn}, val))
                            found_col = 1;
                        end
                    end
                    
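                    % Check sqrs. sqr_poss is in column-major order, so
                    % this cell's own slot is its linear index within
                    % the 3x3 square, not the square number.
                    own_ind = (ii - ind_rows(1) + 1) + 3*(jj - ind_cols(1));
                    if nn ~= own_ind && ~found_sqr
                        if ~isempty(intersect(sqr_poss{nn}, val))
                            found_sqr = 1;
                        end
                    end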
                end
                
                if ~found_row || ~found_col || ~found_sqr
                    poss_out = val;
                end
            end
        end
        
        % Update Possibilities
        % 
        %   Wrapper function. Takes in the object, an update function, and
        %   one or two numbers. If one number is given, then the row,
        %   column, and square with that number are all processed. If two
        %   numbers are given (a row and a column), then only that row,
        %   that column, and the square containing that cell are
        %   processed.
        function obj = updatePoss(obj, func, varargin)
            switch nargin
                case 3
                    obj = updatePossRow(obj, func, varargin{1});
                    obj = updatePossCol(obj, func, varargin{1});
                    obj = updatePossSqr(obj, func, varargin{1});
                case 4
                    obj = updatePossRow(obj, func, varargin{1});
                    obj = updatePossCol(obj, func, varargin{2});
                    obj = updatePossSqr(obj, func, varargin{1}, ...
                        varargin{2});
            end
                
        end
        
        % Update Possibilities Row
        %
        %   Wrapper function. Calls func (solvedUpdate or subsetUpdate) to
        %   update a row.
        function obj = updatePossRow(obj, func, row_num)
            temp_out = ...
                func(...
                    obj.puzz(row_num, : ), ...
                    {obj.poss{row_num, :}});
            
            for ii = 1:9
                obj.poss{row_num, ii} = temp_out{ii};
            end
        end
        
        % Update Possibilities Column
        %
        %   Wrapper function. Calls func (solvedUpdate or subsetUpdate) to
        %   update a column.
        function obj = updatePossCol(obj, func, col_num)
            temp_out = ...
                func(...
                    obj.puzz(:, col_num), ...
                    {obj.poss{:, col_num}}); %#ok<*CCAT>
            
            for ii = 1:9
                obj.poss{ii, col_num} = temp_out{ii};
            end
        end
        
        % Update Possibilities Square
        %
        %   Wrapper function. Calls func (solvedUpdate or subsetUpdate) to
        %   update a square.
        function obj = updatePossSqr(obj, func, varargin)
            if nargin == 3
                [sqr ind_rows ind_cols] = ...
                    getSqr(obj, varargin{1});
            elseif nargin == 4
                [sqr ind_rows ind_cols] = ...
                    getSqr(obj, varargin{1}, varargin{2});
            end
            
            temp_out = func(...
                sqr, ...
                {obj.poss{ind_rows, ind_cols}});
            
            for ii = 1:3
                for jj = 1:3
                    obj.poss{ind_rows(ii), ind_cols(jj)} = ...
                        temp_out{ii, jj};
                end
            end
        end
        
        % getSqr
        %
        %   This function is used to get either the square number or the
        %   set of rows and columns of a square.
        function [sqr varargout] = getSqr(obj, varargin)
            
            switch nargin
                case 2
                    sqr_num = varargin{1};
                    if sqr_num <= 3
                        ind_rows = 1:3;
                        
                        if sqr_num == 1
                            ind_cols = 1:3;
                        elseif sqr_num == 2
                            ind_cols = 4:6;
                        elseif sqr_num == 3
                            ind_cols = 7:9;
                        end
                        
                    elseif sqr_num <= 6
                        ind_rows = 4:6;
                        
                        if sqr_num == 4
                            ind_cols = 1:3;
                        elseif sqr_num == 5
                            ind_cols = 4:6;
                        elseif sqr_num == 6
                            ind_cols = 7:9;
                        end
                        
                    elseif sqr_num <= 9
                        ind_rows = 7:9;
                        
                        if sqr_num == 7
                            ind_cols = 1:3;
                        elseif sqr_num == 8
                            ind_cols = 4:6;
                        elseif sqr_num == 9
                            ind_cols = 7:9;
                        end
                    end
                        
                case 3
                    row_num = varargin{1};
                    col_num = varargin{2};
                    
                    if row_num <= 3
                        if col_num <= 3
                            sqr_num = 1;
                            ind_rows = 1:3;
                            ind_cols = 1:3;
                        elseif col_num <= 6
                            sqr_num = 2;
                            ind_rows = 1:3;
                            ind_cols = 4:6;
                        elseif col_num <= 9
                            sqr_num = 3;
                            ind_rows = 1:3;
                            ind_cols = 7:9;
                        else
                            error('Invalid Column Number')
                        end
                    elseif row_num <= 6
                        if col_num <= 3
                            sqr_num = 4;
                            ind_rows = 4:6;
                            ind_cols = 1:3;
                        elseif col_num <= 6
                            sqr_num = 5;
                            ind_rows = 4:6;
                            ind_cols = 4:6;
                        elseif col_num <= 9
                            sqr_num = 6;
                            ind_rows = 4:6;
                            ind_cols = 7:9;
                        else
                            error('Invalid Column Number')
                        end
                    elseif row_num <= 9
                        if col_num <= 3
                            sqr_num = 7;
                            ind_rows = 7:9;
                            ind_cols = 1:3;
                        elseif col_num <= 6
                            sqr_num = 8;
                            ind_rows = 7:9;
                            ind_cols = 4:6;
                        elseif col_num <= 9
                            sqr_num = 9;
                            ind_rows = 7:9;
                            ind_cols = 7:9;
                        else
                            error('Invalid Column Number')
                        end
                    else
                        error('Invalid Row Number')
                    end
            end
            
            sqr = obj.puzz(ind_rows, ind_cols);
            
            if nargout == 2
                varargout{1} = sqr_num;
            elseif nargout == 3
                varargout{1} = ind_rows;
                varargout{2} = ind_cols;
            elseif nargout == 4
                varargout{1} = sqr_num;
                varargout{2} = ind_rows;
                varargout{3} = ind_cols;
            end
        end
    end
end

% Solved Update
%
%   This generic function is used to update rows, columns, or squares. All
%   the current values of the puzzle matrix are removed from the
%   possibilities matrix.
function poss_out = solvedUpdate(puzz_in, poss_in)
    [num_rows num_cols] = size(puzz_in);
    puz = reshape(puzz_in, 9, 1);
    pos = reshape(poss_in, 9, 1);
    
    ind_filled = 1:9;
    ind_filled = ind_filled(puz ~= 0);
    for ii = 1:numel(ind_filled)
        pos{ind_filled(ii)} = [];
    end
    
    curr_vals = sort(puz(puz ~= 0));
    
    for ii = 1:numel(curr_vals)
        for jj = 1:9
            po = pos{jj};
            if isempty(po)
                continue
            end
            pos{jj} = po(po ~= curr_vals(ii));
        end
    end
    
    poss_out = reshape(pos, num_rows, num_cols);
end

% Subset Update
%
%   This function searches for subsets of repeated numbers within the
%   possibility matrix. If two sets of the same numbers are found, then
%   those two numbers can be removed from the possibility matrix.
function poss_out = subsetUpdate(puzz_in, poss_in)
    poss = reshape(poss_in, 9, 1);
    [num_rows num_cols] = size(puzz_in);
    
    unique_poss = cell(1);
    unique_cntr = ones(1);
    n = 1;
    
    % Count the number of unique sets
    for ii = 1:9
        if isempty(poss_in{ii})
            continue
        end
        
        if n == 1
            unique_poss{n} = poss_in{ii};
            n = n + 1;
            continue
        end
        
        found = 0;
        for nn = 1:numel(unique_poss)
            if numel(intersect(unique_poss{nn}, poss_in{ii})) == ...
                    numel(unique_poss{nn})
                
                unique_cntr(nn) = unique_cntr(nn) + 1;
                found = 1;
            end
        end
        
        if ~found
            temp_poss = {unique_poss{:} poss_in{ii}};
            unique_poss = temp_poss;
            temp_cntr = [unique_cntr 1];
            unique_cntr = temp_cntr;
        end
        
        n = n + 1;
    end
    
    % For the unique sets with more than 2 hits, if the number of hits is
    % equal to the size of the unique set, then remove values from the
    % other possibilities
    if any(unique_cntr > 2)
        ind_uniques = unique_cntr > 2;
        
        unique_sets = {unique_poss{ind_uniques}};
        % Keep the counters lined up with the filtered sets
        unique_cnts = unique_cntr(ind_uniques);
        for ii = 1:numel(unique_sets)
            if numel(unique_sets{ii}) == unique_cnts(ii)
                for jj = 1:9
                    if numel(intersect(poss{jj}, unique_sets{ii})) ...
                            == numel(unique_sets{ii})
                        poss{jj} = intersect(poss{jj}, unique_sets{ii});
                    end
                end
            end
        end
    end
    
    poss_out = reshape(poss, num_rows, num_cols);
end
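
The idea in the Subset Update comment is usually called a naked pair: when two cells in a row, column, or square share the same two possibilities, those two values can be struck from every other cell in that unit. The following is a minimal standalone sketch of that idea, not the subsetUpdate implementation above. The unit cell array is a made-up example, and the code assumes each possibility list is sorted.

```matlab
% Hypothetical unit: the possibility lists for the 9 cells of one row,
% column, or square. Empty lists stand for cells already solved.
unit = {[1 2], [1 2], [1 2 3], [2 4 5], [], [3 5], [4 5 6], [], [1 6]};

for a = 1:8
    for b = a+1:9
        % A naked pair: two cells holding the same two possibilities.
        if numel(unit{a}) == 2 && isequal(unit{a}, unit{b})
            pair = unit{a};
            % Strike the pair's values from every other cell in the unit.
            for k = setdiff(1:9, [a b])
                unit{k} = setdiff(unit{k}, pair);
            end
        end
    end
end

disp(unit)
```

Here cells 1 and 2 form the pair, so cell 3 is cut down to a single 3 and cell 9 to a single 6, which the next solveStep-style pass could write straight into the puzzle.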

Final Thoughts

After ~3 hours, the battery on my computer started to die. So I packed it away and finished it up on the plane ride back home. At the end of the first plane ride, it was able to solve easy puzzles. During the ride home, I added the subset functionality to solve the harder puzzles.

I’ll probably comment later on about the different algorithms and thought processes involved in this code. But for now, I just wanted to share it.
