Machine Learning---Logistic Regression

This section covers the principle of logistic regression and its mathematical derivation. The logistic loss can be written in three different forms; below, each form is presented in turn, its behavior in classification is examined, and the three forms are compared.

The three forms of the loss function are written out below (reconstructed from the accompanying code: $b_p \in \{-1,+1\}$ are the labels, $\mathbf{d}_p \in \mathbb{R}^3$ the inputs with a constant 1 as first entry, $P$ the number of points, and $\sigma(t) = 1/(1+e^{-t})$ the sigmoid):

$$g_1(\mathbf{x}) = \frac{1}{P}\sum_{p=1}^{P} \log\left(1 + e^{-b_p \mathbf{d}_p^{\top}\mathbf{x}}\right) \qquad \text{(log / softmax form)}$$

$$g_2(\mathbf{x}) = \frac{1}{P}\sum_{p=1}^{P} e^{-b_p \mathbf{d}_p^{\top}\mathbf{x}} \qquad \text{(exponential form)}$$

$$g_3(\mathbf{x}) = -\frac{1}{P}\sum_{p=1}^{P} \left[ y_p \log \sigma(\mathbf{d}_p^{\top}\mathbf{x}) + (1-y_p)\log\left(1-\sigma(\mathbf{d}_p^{\top}\mathbf{x})\right) \right], \quad y_p = \frac{1+b_p}{2} \qquad \text{(cross-entropy form)}$$
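As a quick numerical sketch (using synthetic stand-in data, not the real dataset, which is loaded later by load_data), the three losses can be evaluated at a point like this:

P = 22;
D = [ones(1,P); randn(2,P)];        % 3-by-P, constant 1 in row 1
b = sign(randn(P,1)); b(b==0) = 1;  % labels in {-1,+1}
x = randn(3,1);

z  = b.*(D'*x);                     % margins b_p * d_p' * x
g1 = mean(log(1 + exp(-z)));        % log / softmax form
g2 = mean(exp(-z));                 % exponential form

sigma = @(t) 1./(1 + exp(-t));
y  = (1 + b)/2;                     % labels remapped to {0,1}
s  = sigma(D'*x);
g3 = -mean(y.*log(s) + (1-y).*log(1-s));  % cross-entropy form
fprintf('g1=%.4f  g2=%.4f  g3=%.4f\n', g1, g2, g3);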

The corresponding gradients of the three losses are:

$$\nabla g_1(\mathbf{x}) = -\frac{1}{P}\sum_{p=1}^{P} \sigma\left(-b_p \mathbf{d}_p^{\top}\mathbf{x}\right) b_p \mathbf{d}_p, \qquad \sigma(-z) = \frac{e^{-z}}{1+e^{-z}}$$

$$\nabla g_2(\mathbf{x}) = -\frac{1}{P}\sum_{p=1}^{P} e^{-b_p \mathbf{d}_p^{\top}\mathbf{x}}\, b_p \mathbf{d}_p$$

$$\nabla g_3(\mathbf{x}) = \frac{1}{P}\sum_{p=1}^{P} \left(\sigma(\mathbf{d}_p^{\top}\mathbf{x}) - y_p\right) \mathbf{d}_p$$
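These expressions can be sanity-checked against central finite differences; a minimal sketch for the log form (again with synthetic stand-in data):

P = 22;
D = [ones(1,P); randn(2,P)];
b = sign(randn(P,1)); b(b==0) = 1;
x = randn(3,1);

sigma = @(t) 1./(1 + exp(-t));
g     = @(x) mean(log(1 + exp(-b.*(D'*x))));   % log-form loss
grad  = -(1/P) * D * (sigma(-b.*(D'*x)).*b);   % closed-form gradient

eps0 = 1e-6;
fd = zeros(3,1);
for j = 1:3
    e = zeros(3,1); e(j) = eps0;
    fd(j) = (g(x+e) - g(x-e)) / (2*eps0);      % central difference
end
fprintf('max abs diff: %g\n', max(abs(fd - grad)));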

The first form and the third form are equivalent. The derivation: with $y_p = \frac{1+b_p}{2}$ (so $y_p = 1$ when $b_p = +1$ and $y_p = 0$ when $b_p = -1$) and the identity $1-\sigma(t) = \sigma(-t)$, each cross-entropy term reduces to a single sigmoid:

$$-\left[ y_p \log \sigma(\mathbf{d}_p^{\top}\mathbf{x}) + (1-y_p)\log\left(1-\sigma(\mathbf{d}_p^{\top}\mathbf{x})\right) \right] = -\log \sigma\left(b_p \mathbf{d}_p^{\top}\mathbf{x}\right) = \log\left(1 + e^{-b_p \mathbf{d}_p^{\top}\mathbf{x}}\right),$$

and averaging over $p$ recovers $g_1$ exactly.
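The equivalence is easy to confirm numerically; a short sketch (synthetic stand-in data) that evaluates both forms at the same point:

P = 22;
D = [ones(1,P); randn(2,P)];
b = sign(randn(P,1)); b(b==0) = 1;
x = randn(3,1);

sigma = @(t) 1./(1 + exp(-t));
g_log = mean(log(1 + exp(-b.*(D'*x))));        % first form
y     = (1 + b)/2;
s     = sigma(D'*x);
g_ce  = -mean(y.*log(s) + (1-y).*log(1-s));    % third form
fprintf('log form: %.12f  cross-entropy: %.12f\n', g_log, g_ce);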


Steepest descent

The steepest-descent update rule was covered in an earlier chapter:

$$\mathbf{x}^{k} = \mathbf{x}^{k-1} - \alpha \nabla g\left(\mathbf{x}^{k-1}\right),$$

where $\alpha$ is the step length and $\nabla g$ is the gradient of the chosen loss.
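Abstractly, the routines below all implement the same loop; here is a generic sketch (steepest_descent and grad_fun are illustrative names, not part of the original code):

function x = steepest_descent(grad_fun, x0, alpha, max_its, tol)
    % generic steepest descent: grad_fun returns the gradient at x
    x = x0;
    for iter = 1:max_its
        g = grad_fun(x);
        if norm(g) <= tol
            break;              % gradient small enough: stop
        end
        x = x - alpha*g;        % x^k = x^(k-1) - alpha * grad
    end
end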

The code below should make this easier to follow:

main.m

[D,b] = load_data();

%%% run exponential-loss logistic regression %%%
x0 = randn(3,1);      % initial point
alpha = 10^-2;        % step length
x = grad_descent_exp_logistic(D,b,x0,alpha);

%%% run log-loss logistic regression %%%
alpha = 10^-1;        % step length
y = grad_descent_log_logistic(D,b,x0,alpha);

%%% plot everything: points and separators %%%
plot_all(D',b,x,y);


load_data.m

function [A,b] = load_data()
    data = load('exp_vs_log_data.mat');
    data = data.data;
    A = data(:,1:3);    % columns 1-3: constant 1 plus the two features
    A = A';             % transpose to 3-by-P
    b = data(:,4);      % column 4: labels in {-1,+1}
end
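The dataset exp_vs_log_data.mat ships with the original resources. If it is unavailable, a stand-in generator with the same output layout might look like the sketch below (make_synthetic_data is a hypothetical name, and the two Gaussian clusters are an assumption, not the real data):

function [A,b] = make_synthetic_data()
    % synthetic stand-in: two Gaussian clusters, labels in {-1,+1}
    P = 22;
    pos = randn(P/2,2) + 1.5;            % class +1 cluster
    neg = randn(P/2,2) - 1.5;            % class -1 cluster
    A = [ones(P,1) [pos; neg]]';         % 3-by-P, constant 1 in row 1
    b = [ones(P/2,1); -ones(P/2,1)];     % labels
end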

grad_descent_exp_logistic.m

function x = grad_descent_exp_logistic(D,b,x0,alpha)
    % minimize the exponential-form loss by steepest descent
    x = x0;
    iter = 1;
    max_its = 3000;
    grad = 1;
    m = 22;                              % number of data points
    while norm(grad) > 10^-6 && iter < max_its
        % compute the gradient of the exponential loss
        gradsum = 0;
        for i = 1:m
            z = b(i)*(D(:,i)'*x);        % margin of point i
            gradsum = gradsum - exp(-z)*b(i)*D(:,i);
        end
        grad = (1/m)*gradsum;

        % descent step
        x = x - alpha*grad;

        % update iteration count
        iter = iter + 1;
    end
end

grad_descent_log_logistic.m

function x = grad_descent_log_logistic(D,b,x0,alpha)
    % minimize the log-form (softmax) loss by steepest descent
    x = x0;
    iter = 1;
    max_its = 3000;
    grad = 1;
    m = 22;                              % number of data points
    while norm(grad) > 10^-6 && iter < max_its
        % compute the gradient of the log loss
        gradsum = 0;
        for i = 1:m
            z = b(i)*(D(:,i)'*x);        % margin of point i
            % sigma(-z) = exp(-z)/(1+exp(-z)) is the per-point weight
            gradsum = gradsum - (exp(-z)/(1+exp(-z)))*b(i)*D(:,i);
        end
        grad = (1/m)*gradsum;

        % descent step
        x = x - alpha*grad;

        % update iteration count
        iter = iter + 1;
    end
end
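One numerical caveat: the weight exp(-z)/(1+exp(-z)) equals sigma(-z) = 1/(1+exp(z)), and the latter is the safer way to compute it, since for a large negative margin z the naive ratio overflows to Inf/Inf = NaN. A small sketch of the difference:

z = -800;                          % a large negative margin
naive  = exp(-z)/(1 + exp(-z));    % exp(800) overflows: Inf/Inf -> NaN
stable = 1/(1 + exp(z));           % sigma(-z) computed directly -> 1
fprintf('naive: %g   stable: %g\n', naive, stable);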


plot_all.m

function plot_all(A,b,x,y)
    % plot the data points
    ind = find(b == 1);
    scatter(A(ind,2),A(ind,3),'Linewidth',2,'Markeredgecolor','b','markerFacecolor','none');
    hold on
    ind = find(b == -1);
    scatter(A(ind,2),A(ind,3),'Linewidth',2,'Markeredgecolor','r','markerFacecolor','none');

    % plot the two separators: x (exponential loss) in magenta,
    % y (log loss) in black
    s = min(A(:,2)):.01:max(A(:,2));
    plot(s,(-x(1)-x(2)*s)/x(3),'m','linewidth',2);
    plot(s,(-y(1)-y(2)*s)/y(3),'k','linewidth',2);

    set(gcf,'color','w');
    axis([(min(A(:,2)) - 0.1) (max(A(:,2)) + 0.1) (min(A(:,3)) - 0.1) (max(A(:,3)) + 0.1)])
    box off

    % axis labels
    xlabel('a_1','Fontsize',14)
    ylabel('a_2  ','Fontsize',14)
    set(get(gca,'YLabel'),'Rotation',0)
end

Results

[Figure: the 22 data points with the two learned linear separators.]

The black line is the separator found with the log loss (the first form); the magenta line is the one found with the exponential loss (the second form).
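To compare the two separators beyond the picture, one can count how many training points each misclassifies; a short sketch using the routines above:

[D,b] = load_data();
x0 = randn(3,1);
x = grad_descent_exp_logistic(D, b, x0, 10^-2);
y = grad_descent_log_logistic(D, b, x0, 10^-1);
err_exp = sum(sign(D'*x) ~= b);    % errors of the exponential-loss separator
err_log = sum(sign(D'*y) ~= b);    % errors of the log-loss separator
fprintf('exp-loss errors: %d   log-loss errors: %d\n', err_exp, err_log);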

Resources: the code and the dataset are available in the accompanying resources.

[Photo: Yanqi Lake Campus, University of Chinese Academy of Sciences]

