[MATLAB] Sudoku Solver Rev 4.0

As promised, here’s the code for the newest revision of my Sudoku Solver.

main.m

clear
clc

% Example driver: each puzzle is a 9x9 matrix written row by row, with the
% unknown cells entered as 0 and a double space between the 3x3 boxes.

% American Airlines - 20 January 2014

gentle = [
    0 0 0  9 5 0  0 0 4
    0 0 6  3 0 2  0 8 0
    5 0 0  0 0 0  0 0 7
    0 9 0  6 0 0  0 0 0
    2 3 5  0 0 0  6 9 1
    0 0 0  0 0 9  0 7 0
    1 0 0  0 0 0  0 0 9
    0 0 0  7 0 6  5 0 0
    7 0 0  0 4 1  0 0 0
    ];

sudokuSolver(gentle);

% To try a harder puzzle, uncomment one of the blocks below.

% moderate = [
%     0 8 0  0 5 0  0 1 3
%     2 0 0  0 0 0  0 0 6
%     0 0 1  9 0 0  0 0 2
%     0 0 8  0 1 0  0 0 0
%     0 0 5  4 6 8  7 0 0
%     0 0 0  0 9 0  1 0 0
%     4 0 0  0 0 6  8 0 0
%     3 0 0  0 0 0  0 0 7
%     8 9 0  0 7 0  0 2 0
%     ];
%
% sudokuSolver(moderate);

% diabolical = [
%     9 0 0  5 0 1  0 0 6
%     0 0 0  0 8 0  0 3 0
%     0 5 0  7 4 0  0 0 0
%     0 0 7  0 0 5  0 0 0
%     5 0 0  0 6 0  0 0 7
%     0 0 0  1 0 0  2 0 0
%     0 0 0  0 3 8  0 4 1
%     0 6 0  0 0 0  0 0 0
%     1 0 0  6 0 7  0 0 2
%     ];
%
% sudokuSolver(diabolical);

sudokuSolver.m

function varargout = sudokuSolver(A)

% sudokuSolver  Solve a 9x9 Sudoku puzzle.
%
%   sudokuSolver(A)            prints the solved grid.
%   P = sudokuSolver(A)        returns the solved 9x9 matrix.
%   [P, POSS] = sudokuSolver(A) also returns the final possibilities
%                              cell array of the solver.
%
%   A is a 9x9 matrix with every unknown cell populated with 0.

solved = solve(num(A));

switch nargout
    case 0
        disp(solved)
    case 1
        varargout{1} = solved.puzz;
    case 2
        varargout = {solved.puzz, solved.poss};
end

end

num.m

classdef num
    % num  Sudoku puzzle state used by sudokuSolver.
    %
    %   obj = num(A) stores the 9x9 puzzle A (0 marks an unknown cell) and
    %   builds the matching 9x9 possibilities cell array.
    %   obj = solve(obj) fills in the puzzle.
    %   disp(obj) prints the grid in 3x3 blocks.

    properties
        puzz_orig   % The original puzzle input
        puzz        % The puzzle output, 9x9 matrix of values 0 to 9
        poss        % The possibilities matrix, 9x9 cell matrix
                    %   corresponding to each cell in the puzzle matrix

        guess       % The current guess
        guess_row   % Current guess's row
        guess_col   % Current guess's column
        guess_orig  % num object of puzzle before the guess
        guessed     % Vector holding the guesses we have tried
    end

    methods

        % Constructor
        %
        %   Creates a num object that holds the puzzle matrix and the
        %   possibilities cell array. The puzzle matrix is a 9x9 matrix
        %   with 0s for the unknown locations. The possibilities matrix is
        %   a 9x9 cell array whose entries are subsets of the integers 1
        %   through 9.
        function obj = num(puzz)
            if ~isequal(size(puzz), [9 9])
                error('Puzzle must be a 9x9 matrix')
            end

            obj.puzz = puzz;
            obj.puzz_orig = puzz;

            % Start with every value possible in every cell ...
            obj.poss = repmat({1:9}, 9, 9);

            % ... then let solvedUpdate strike the given clues from row,
            % column and square ii (one index updates all three).
            for ii = 1:9
                obj = updatePoss(obj, @solvedUpdate, ii);
            end
        end

        % Solve
        %
        %   Called once after construction to solve the puzzle, in three
        %   stages:
        %     1. Run singles, then unique sets, then subsets
        %        (solveUniqueFirst).
        %     2. If cells remain unsolved, restart from the original puzzle
        %        and run subsets before unique sets (solveSubsetFirst).
        %     3. If cells still remain, restart again and guess.
        %   Guessing only ever branches on one cell. Each value of that
        %   cell is tried in turn, and solveGuess raises an error once all
        %   of them have failed.
        function obj = solve(obj)

            obj = solveUniqueFirst(obj);

            % Unsolved cells: start over with the other elimination order.
            if any(obj.puzz(:) == 0)
                obj = num(obj.puzz_orig);
                obj = solveSubsetFirst(obj);
            end

            % Still unsolved: start over and guess.
            if any(obj.puzz(:) == 0)
                obj = num(obj.puzz_orig);
                done_guessing = false;
                while ~done_guessing
                    obj = solveSingle(obj);
                    obj = solveGuess(obj);
                    guessed_obj = obj;

                    obj = solveUniqueFirst(obj);

                    if any(obj.puzz(:) == 0)
                        obj = guessed_obj;
                        obj = solveSubsetFirst(obj);
                    end

                    done_guessing = all(obj.puzz(:) ~= 0);
                end
            end
        end

        % Repeat singles -> unique sets -> subsets until no step changes
        % anything. After any change, start again from singles.
        function obj = solveUniqueFirst(obj)
            done = false;
            while ~done
                [obj, single_changed] = solveSingle(obj);
                if single_changed
                    continue
                end

                [obj, unique_changed] = solveUnique(obj);
                if unique_changed
                    continue
                end

                [obj, subset_changed] = solveSubset(obj);
                done = ~subset_changed;
            end
        end

        % Repeat singles -> subsets -> unique sets until no step changes
        % anything. After any change, start again from singles.
        function obj = solveSubsetFirst(obj)
            done = false;
            while ~done
                [obj, single_changed] = solveSingle(obj);
                if single_changed
                    continue
                end

                [obj, subset_changed] = solveSubset(obj);
                if subset_changed
                    continue
                end

                [obj, unique_changed] = solveUnique(obj);
                done = ~unique_changed;
            end
        end

        % Loops through every cell to check whether only one possibility
        % exists there. This covers a single remaining candidate, or a
        % value found by checkOnlyPoss. Each cell solved this way is
        % written into puzz, and its row, column and square are updated.
        % Optional second output: 1 if any of those updates reported a
        % change.
        function [obj, varargout] = solveSingle(obj)

            changed = 0;
            for ii = 1:9
                for jj = 1:9
                    if isempty(obj.poss{ii, jj})
                        continue
                    end

                    obj.poss{ii, jj} = checkOnlyPoss(obj, ii, jj);

                    if numel(obj.poss{ii, jj}) == 1
                        obj.puzz(ii, jj) = obj.poss{ii, jj};
                        [obj, solved_changed] = updatePoss( ...
                            obj, @solvedUpdate, ii, jj);
                        if solved_changed
                            changed = 1;
                        end
                    end
                end
            end

            if nargout == 2
                varargout{1} = changed;
            end
        end

        % Loops over all rows/columns/squares to check for subsets.
        % Optional argument: 'reverse' or 'random' visiting order.
        % Optional second output: 1 if any unit changed.
        function [obj, varargout] = solveSubset(obj, varargin)

            if nargin == 2
                switch varargin{1}
                    case 'reverse'
                        iters = 9:-1:1;
                    case 'random'
                        iters = randperm(9);
                    otherwise
                        iters = 1:9;
                end
            else
                iters = 1:9;
            end

            changed = 0;
            for ii = iters
                [obj, subset_changed] = updatePoss(obj, @subsetUpdate, ii);
                if subset_changed
                    changed = 1;
                end
            end

            if nargout == 2
                varargout{1} = changed;
            end
        end

        % Loops over all rows/columns/squares to check for unique sets.
        % Optional argument: 'reverse' or 'random' visiting order.
        % Optional second output: 1 if any unit changed.
        function [obj, varargout] = solveUnique(obj, varargin)

            if nargin == 2
                switch varargin{1}
                    case 'reverse'
                        iters = 9:-1:1;
                    case 'random'
                        iters = randperm(9);
                    otherwise
                        iters = 1:9;
                end
            else
                iters = 1:9;
            end

            changed = 0;
            for ii = iters
                [obj, unique_changed] = updatePoss(obj, @uniqueUpdate, ii);
                if unique_changed
                    changed = 1;
                end
            end

            if nargout == 2
                varargout{1} = changed;
            end
        end

        % Guess
        %
        %   First call (no guess_orig yet): choose a cell and a random
        %   candidate in it, and snapshot the pre-guess state in
        %   guess_orig. Later calls mean the previous guess failed. They
        %   roll puzz and poss back to that snapshot and try a different,
        %   untried candidate of the same cell. In both cases the chosen
        %   value becomes the cell's only possibility and solveSingle is
        %   run.
        function obj = solveGuess(obj)
            if isempty(obj.guess_orig)
                % Number of possibilities left in each cell
                poss_num = cellfun(@numel, obj.poss);
                open_counts = poss_num(poss_num ~= 0);
                if isempty(open_counts)
                    error('No cell left to guess')
                end

                % Candidate cells are those with the MOST possibilities.
                % NOTE(review): the usual heuristic is the fewest; the
                % maximum is kept because that is what this solver has
                % always used.
                max_poss = max(open_counts);
                [guess_rows, guess_cols] = find(poss_num == max_poss);

                % Take a random candidate cell ...
                pick = randi(numel(guess_rows));
                obj.guess_row = guess_rows(pick);
                obj.guess_col = guess_cols(pick);

                % ... and a random possibility within it
                guess_poss = obj.poss{obj.guess_row, obj.guess_col};
                obj.guess = guess_poss(randi(numel(guess_poss)));
                obj.guessed = obj.guess;

                % Snapshot of the state before the guess is applied below
                obj.guess_orig = obj;

            else
                % Guessed wrong: try a value of the same cell not tried yet
                guess_poss = ...
                    obj.guess_orig.poss{obj.guess_row, obj.guess_col};
                guess_poss = setdiff(guess_poss, obj.guessed);

                if isempty(guess_poss)
                    error('No more possible guesses')
                end

                obj.guess = guess_poss(randi(numel(guess_poss)));
                obj.guessed = [obj.guessed obj.guess];

                % Reset the possibility and puzzle matrices to the state
                % before the initial guess
                obj.poss = obj.guess_orig.poss;
                obj.puzz = obj.guess_orig.puzz;
            end

            % Make the guess the cell's only possibility, then propagate
            obj.poss{obj.guess_row, obj.guess_col} = obj.guess;
            obj = solveSingle(obj);
        end

        % Check Only Possible
        %
        %   Returns the possibilities for cell (ii, jj), reduced to a
        %   single value val when no OTHER cell of the same row, column or
        %   square still has val as a possibility ("hidden single").
        %   Otherwise the cell's possibilities are returned unchanged. The
        %   caller (solveSingle) writes single values into the puzzle.
        function poss_out = checkOnlyPoss(obj, ii, jj)
            poss_out = obj.poss{ii, jj};

            if isempty(poss_out)
                return
            end

            row_poss = obj.poss(ii, :);
            col_poss = obj.poss(:, jj);

            [~, ind_rows, ind_cols] = getSqr(obj, ii, jj);
            sqr_poss = obj.poss(ind_rows, ind_cols);

            % Column-major position of (ii, jj) inside its 3x3 square. The
            % cell itself must be skipped in the square check, just as
            % column jj is skipped in the row check and row ii in the
            % column check.
            self_sqr = (jj - ind_cols(1)) * 3 + (ii - ind_rows(1)) + 1;

            vals = poss_out;
            for mm = 1:numel(vals)
                val = vals(mm);

                found_row = false;
                found_col = false;
                found_sqr = false;

                for nn = 1:9
                    % Check rows
                    if nn ~= jj && ~found_row && any(row_poss{nn} == val)
                        found_row = true;
                    end

                    % Check cols
                    if nn ~= ii && ~found_col && any(col_poss{nn} == val)
                        found_col = true;
                    end

                    % Check sqrs
                    if nn ~= self_sqr && ~found_sqr && ...
                            any(sqr_poss{nn} == val)
                        found_sqr = true;
                    end
                end

                if ~found_row || ~found_col || ~found_sqr
                    poss_out = val;
                end
            end
        end

        % Update Possibilities
        %
        %   Applies func to a row, a column and a square.
        %   updatePoss(obj, func, k)      -> row k, column k and square k.
        %   updatePoss(obj, func, r, c)   -> row r, column c and the square
        %                                    containing cell (r, c).
        %   changed is true if any of the three updates reported a change.
        function [obj, changed] = updatePoss(obj, func, varargin)
            switch numel(varargin)
                case 1
                    row_num = varargin{1};
                    col_num = varargin{1};
                case 2
                    row_num = varargin{1};
                    col_num = varargin{2};
                otherwise
                    error('updatePoss expects one or two indices')
            end

            [obj, changed_row] = updatePossRow(obj, func, row_num);
            [obj, changed_col] = updatePossCol(obj, func, col_num);
            [obj, changed_sqr] = updatePossSqr(obj, func, varargin{:});

            changed = any([changed_row changed_col changed_sqr]);
        end

        % Update Possibilities Row
        %
        %   Calls func(values, possibilities) on row row_num and stores the
        %   returned possibilities.
        function [obj, changed] = updatePossRow(obj, func, row_num)
            [temp_out, changed] = ...
                func(obj.puzz(row_num, :), obj.poss(row_num, :));
            obj.poss(row_num, :) = reshape(temp_out, 1, 9);
        end

        % Update Possibilities Column
        %
        %   Calls func(values, possibilities) on column col_num and stores
        %   the returned possibilities.
        function [obj, changed] = updatePossCol(obj, func, col_num)
            [temp_out, changed] = ...
                func(obj.puzz(:, col_num), obj.poss(:, col_num));
            obj.poss(:, col_num) = reshape(temp_out, 9, 1);
        end

        % Update Possibilities Square
        %
        %   Calls func(values, possibilities) on a 3x3 square, chosen
        %   either by square number or by a (row, col) inside it, and
        %   stores the returned possibilities.
        function [obj, changed] = updatePossSqr(obj, func, varargin)
            [sqr, ind_rows, ind_cols] = getSqr(obj, varargin{:});

            [temp_out, changed] = func(sqr, obj.poss(ind_rows, ind_cols));
            obj.poss(ind_rows, ind_cols) = reshape(temp_out, 3, 3);
        end

        % getSqr
        %
        %   getSqr(obj, sqr_num)        square by number (1..9, numbered
        %                               left to right, top to bottom).
        %   getSqr(obj, row_num, col_num) square containing that cell.
        %
        %   Outputs: sqr (3x3 puzzle values), then depending on nargout:
        %     2 outputs -> sqr_num
        %     3 outputs -> ind_rows, ind_cols
        %     4 outputs -> sqr_num, ind_rows, ind_cols
        function [sqr, varargout] = getSqr(obj, varargin)

            switch nargin
                case 2
                    sqr_num = varargin{1};
                    if sqr_num < 1 || sqr_num > 9
                        error('Invalid Square Number')
                    end
                    band = floor((sqr_num - 1) / 3);   % 0..2, top to bottom
                    stack = mod(sqr_num - 1, 3);       % 0..2, left to right

                case 3
                    row_num = varargin{1};
                    col_num = varargin{2};
                    if row_num < 1 || row_num > 9
                        error('Invalid Row Number')
                    end
                    if col_num < 1 || col_num > 9
                        error('Invalid Column Number')
                    end
                    band = floor((row_num - 1) / 3);
                    stack = floor((col_num - 1) / 3);
                    % Squares are numbered 1..9 across the whole grid
                    sqr_num = 3 * band + stack + 1;

                otherwise
                    error('getSqr expects a square number or a row and column')
            end

            ind_rows = 3 * band + (1:3);
            ind_cols = 3 * stack + (1:3);

            sqr = obj.puzz(ind_rows, ind_cols);

            if nargout == 2
                varargout{1} = sqr_num;
            elseif nargout == 3
                varargout{1} = ind_rows;
                varargout{2} = ind_cols;
            elseif nargout == 4
                varargout{1} = sqr_num;
                varargout{2} = ind_rows;
                varargout{3} = ind_cols;
            end
        end

        % Prints the puzzle grid, with extra spacing between the 3x3
        % blocks.
        function disp(obj)
            for ii = 1:9
                str = '';
                for jj = 1:9
                    str = [str num2str(obj.puzz(ii, jj)) ' ']; %#ok<AGROW>
                    if mod(jj, 3) == 0
                        str = [str ' ']; %#ok<AGROW>
                    end
                end
                disp(str)
                if mod(ii, 3) == 0
                    disp(' ')
                end
            end
        end
    end
end

% Solved Update
%
%   This generic function is used to update rows, columns, or squares. All
%   the current values of the puzzle matrix are removed from the
%   possibilities matrix.
function [poss_out, changed] = solvedUpdate(puzz_in, poss_in)
% Solved Update
%
%   Generic update for one row, column or square:
%   - the possibilities of every filled cell are cleared;
%   - every value already placed in the unit is removed from the
%     possibilities of the remaining cells.
%
%   Inputs:  puzz_in  - the unit's values (1x9, 9x1 or 3x3); 0 = unsolved
%            poss_in  - the unit's 9 possibility vectors, in the same
%                       linear order as puzz_in
%   Outputs: poss_out - updated possibilities, shaped like puzz_in
%            changed  - 1 only if some possibility was actually cleared
%                       or removed, otherwise 0

    changed = 0;
    [num_rows, num_cols] = size(puzz_in);
    puz = reshape(puzz_in, 9, 1);
    pos = reshape(poss_in, 9, 1);

    placed = puz(puz ~= 0);

    for ii = 1:9
        po = pos{ii};

        if puz(ii) ~= 0
            % Filled cell: nothing left to choose
            pos{ii} = [];
            if ~isempty(po)
                changed = 1;
            end
        elseif ~isempty(po)
            % Open cell: drop every value already placed in the unit
            keep = ~ismember(po, placed);
            if ~all(keep)
                pos{ii} = po(keep);
                changed = 1;
            end
        end
    end

    poss_out = reshape(pos, num_rows, num_cols);
end

% Subset Update
%
%   This function searches for subsets of repeated numbers within the
%   possibility matrix. If two sets of the same numbers are found, then 
%   those two numbers can be removed from the possibility matrix.
function [poss_out, changed] = subsetUpdate(puzz_in, poss_in)
% Subset Update
%
%   Hidden-subset elimination for one row, column or square.
%
%   For every group V of k values (2 <= k < number of values still
%   possible in the unit), find the cells whose possibilities contain at
%   least one value of V. If exactly k cells do, those k cells must take
%   exactly the k values of V. Every other value is therefore removed from
%   them. Every value of V is guaranteed to appear in some cell because V
%   is drawn only from values still possible somewhere in the unit.
%
%   A cell that merely CONTAINS all of V is not by itself restricted to
%   V: the values of V may still fit elsewhere in the unit. Only the
%   "no other cell can hold any value of V" condition above is a valid
%   deduction.
%
%   Inputs:  puzz_in  - the unit's values (1x9, 9x1 or 3x3); 0 = unsolved
%            poss_in  - the unit's 9 possibility vectors, in the same
%                       linear order as puzz_in
%   Outputs: poss_out - updated possibilities, shaped like puzz_in
%            changed  - 1 if any possibility was removed, otherwise 0

    changed = 0;
    poss = reshape(poss_in, 9, 1);
    [num_rows, num_cols] = size(puzz_in);

    if ~any(puzz_in(:) == 0)
        poss_out = reshape(poss, num_rows, num_cols);
        return
    end

    % member(c, v) is true while value v is still possible in cell c
    member = false(9, 9);
    for c = 1:9
        member(c, poss{c}) = true;
    end

    % Values that can still go somewhere in this unit
    open_vals = find(any(member, 1));
    n_open = numel(open_vals);

    for k = 2:(n_open - 1)
        combos = nchoosek(open_vals, k);
        for r = 1:size(combos, 1)
            in_set = false(1, 9);
            in_set(combos(r, :)) = true;

            % Cells that could still hold at least one value of the group
            cells = any(member(:, in_set), 2);
            if nnz(cells) ~= k
                continue
            end

            % Those k cells are reserved for the group: drop other values
            extra = member(cells, ~in_set);
            if any(extra(:))
                member(cells, ~in_set) = false;
                changed = 1;
            end
        end
    end

    % Write the reduced sets back, preserving each cell's value order
    for c = 1:9
        if ~isempty(poss{c})
            poss{c} = poss{c}(member(c, poss{c}));
        end
    end

    poss_out = reshape(poss, num_rows, num_cols);
end

% Unique Update
function [poss_out, changed] = uniqueUpdate(puzz_in, poss_in)
% Unique Update
%
%   Naked-subset elimination for one row, column or square. Cells with
%   identical possibility sets are grouped. When a set of k values is
%   shared by k cells (k >= 2), those values are removed from every other
%   non-empty cell of the unit.
%
%   Inputs:  puzz_in  - the unit's values (1x9, 9x1 or 3x3); 0 = unsolved
%            poss_in  - the unit's 9 possibility vectors, in the same
%                       linear order as puzz_in
%   Outputs: poss_out - updated possibilities, shaped like puzz_in
%            changed  - 1 if any possibility was removed, otherwise 0

    changed = 0;
    cells = reshape(poss_in, 9, 1);
    out_size = size(puzz_in);

    % A fully filled unit has nothing to eliminate
    if all(puzz_in(:) ~= 0)
        poss_out = reshape(cells, out_size);
        return
    end

    % Distinct non-empty possibility sets, in order of first appearance,
    % with the number of cells holding each one
    groups = {};
    counts = [];
    for k = 1:9
        candidate = cells{k};
        if isempty(candidate)
            continue
        end

        hit = 0;
        for g = 1:numel(groups)
            if isempty(setdiff(candidate, groups{g})) && ...
                    numel(groups{g}) == numel(candidate)
                counts(g) = counts(g) + 1;
                hit = 1;
            end
        end

        if ~hit
            groups{end + 1} = candidate; %#ok<AGROW>
            counts(end + 1) = 1; %#ok<AGROW>
        end
    end

    % A set of k values held by exactly k cells cannot occur elsewhere
    for g = find(counts >= 2)
        members = groups{g};
        if numel(members) ~= counts(g)
            continue
        end

        for k = 1:9
            current = cells{k};
            if isempty(current)
                continue
            end

            % Skip the cells that form the subset themselves
            if numel(current) == numel(members)
                if all(members == current)
                    continue
                end
            end

            remaining = setdiff(current, members);
            if numel(remaining) < numel(current)
                cells{k} = remaining;
                changed = 1;
            end
        end
    end

    poss_out = reshape(cells, out_size);
end

Leave a comment