Machine Learning---Logistic回归
下面分别写出这三种形式的损失函数:
下面分别写出这三种损失函数的梯度形式:
其中第一种形式和第三种形式是等价的,推导如下:
Steepest descent
前面章节已经讲过最速下降法的更新公式,如下:
下面将给出代码这样容易理解:
main.m
<span style="font-family:Times New Roman;">[D,b] = load_data(); %%% run exp and log convex logistic regression %%% x0 = randn(3,1); % initial point alpha = 10^-2; % step length x = grad_descent_exp_logistic(D,b,x0,alpha); % Run log convex logistic regression alpha = 10^-1; % step length y = grad_descent_log_logistic(D,b,x0,alpha); %%% plot everything, pts and lines %%% plot_all(D',b,x,y);</span>
<span style="font-family:Times New Roman;"> function [A,b] = load_data() data = load('exp_vs_log_data.mat'); data = data.data; A = data(:,1:3); A = A'; b = data(:,4); end</span>
grad_descent_exp_logistic.m
<span style="font-family:Times New Roman;">function x = grad_descent_exp_logistic(D,b,x0,alpha) % Initializations x = x0; iter = 1; max_its = 3000; grad = 1; m=22; while norm(grad) > 10^-6 && iter < max_its % compute gradient sum=0; for i=1:22 z=b(i)*(D(:,i)'*x); tmp1=exp(-z); tmp2=-b(i)*D(:,i)'; sum=sum+tmp1*tmp2'; end grad=(1/22)*sum; % your code goes here! x = x - alpha*grad; % update iteration count iter = iter + 1; end end</span>
grad_descent_log_logistic.m
<span style="font-family:Times New Roman;">function x = grad_descent_log_logistic(D,b,x0,alpha) % Initializations x = x0; iter = 1; max_its = 3000; grad = 1; m=22; while norm(grad) > 10^-6 && iter < max_its sum=0; for i=1:22 z=b(i)*(D(:,i)'*x); tmp1=exp(-z)/sigmoid(z); tmp2=-b(i)*D(:,i)'; sum=sum+tmp1*tmp2'; end grad=(1/22)*sum; x = x - alpha*grad; % update iteration count iter = iter + 1; end end</span>
<span style="font-family:Times New Roman;">function plot_all(A,b,x,y) % plot points ind = find(b == 1); scatter(A(ind,2),A(ind,3),'Linewidth',2,'Markeredgecolor','b','markerFacecolor','none'); hold on ind = find(b == -1); scatter(A(ind,2),A(ind,3),'Linewidth',2,'Markeredgecolor','r','markerFacecolor','none'); hold on % plot separators s =[min(A(:,2)):.01:max(A(:,2))]; plot (s,(-x(1)-x(2)*s)/x(3),'m','linewidth',2); hold on plot (s,(-y(1)-y(2)*s)/y(3),'k','linewidth',2); hold on set(gcf,'color','w'); axis([ (min(A(:,2)) - 0.1) (max(A(:,2)) + 0.1) (min(A(:,3)) - 0.1) (max(A(:,3)) + 0.1)]) box off % graph info labels xlabel('a_1','Fontsize',14) ylabel('a_2 ','Fontsize',14) set(get(gca,'YLabel'),'Rotation',0) end</span>
结果图
郑重声明:本站内容如果来自互联网及其他传播媒体,其版权均属原媒体及文章作者所有。转载目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,也不构成任何其他建议。