当前位置: 首页 > 编程日记 > 正文

【cs229-Lecture2】Linear Regression with One Variable (Week 1)(含测试数据和源码)

从Ⅱ到Ⅳ都在讲的是线性回归,其中第Ⅱ章讲得是简单线性回归(simple linear regression, SLR)(单变量),第Ⅲ章讲的是线代基础,第Ⅳ章讲的是多元回归(大于一个自变量)。

本文的目的主要是对Ⅱ章中出现的一些算法进行实现,适合的人群为已经看完本章节Stanford课程的学者。本人只是一名初学者,尽可能以白话的方式来说明问题。不足之处,还请指正。

在开始讨论具体步骤之前,首先给出简要的思维路线:

1.拥有一个点集,为了得到一条最佳拟合的直线;

2.通过“最小二乘法”来衡量拟合程度,得到代价方程;

3.利用“梯度下降算法”使得代价方程取得极小值点;



首先,介绍几个概念:

回归在数学上来说是给定一个点集,能够用一条曲线去拟合之。如果这个曲线是一条直线,那就被称为线性回归;如果曲线是一条二次曲线,就被称为二次回归,回归还有很多的变种,如locally weighted回归,logistic回归等等。

课程中得到的h就是线性回归方程:

image


下面,首先来介绍一下单变量的线性回归:

问题是这样的:给定一个点集,找出一条直线去拟合,要求拟合的效果达到最佳(最佳拟合)。

既然是直线,我们先假设直线的方程为:image

如图:image

点集有了,直线方程有了,接下来,我们要做的就是计算出imageimage,使得拟合效果达到最佳(最佳拟合)。

那么,拟合效果的评判标准是什么呢?换句话说,我们需要知道一种对拟合效果的度量。

在这里,我们提出“最小二乘法”:(以下摘自wiki)

最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。

利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。

对于“最小二乘法”就不再展开讨论,只要知道他是一个度量标准,我们可以用它来评判计算出的直线方程是否达到了最佳拟合就够了。

那么,回到问题上来,在单变量的线性回归中,这个拟合效果的表达式是利用最小二乘法将未知量残差平方和最小化

image

结合课程,定义了一个成本函数:

image

其实,到这里,要是把点集的具体数值代入到成本函数中,就已经完全抽象出了一个高等数学问题(解一个二元函数的最小值问题)。

image

其中,a,b,c,d,e,f均为已知。

课程中介绍了一种叫“Gradient descent”的方法——梯度下降算法

image

两张图说明算法的基本思想:

imageimage

image

所谓梯度下降算法(一种求局部最优解的方法),举个例子就好比你现在在一座山上,你想要尽快地到达山底(极小值点),这是一个下降的过程,这里就涉及到了两个问题:1)你下山的时候,跨多大的步子(当然,肯定不是越大越好,因为有一种可能就是你一步跨地太大,正好错过了极小的位置);2)你朝哪个方向跨步(注意,这个方向是不断变化的,你每到一个新的位置,要判断一下下一步朝那个方向走才是最好的,但是有一点可以肯定的是,要想尽快到达最低点,应从最陡的地方下山)。

那么,什么时候算是你到了一个极小点呢,显然,当你所处的位置发生的变化不断减小,直至收敛于某一位置,就说明那个位置就是一个极小值点。

so,我们来看image的变化,则我们需要让imageimage求偏导,倒数代表变化率。也就是要朝着对陡的地方下山(因为沿着最陡显然比较快),就得到了image的变化情况:image

image

image

简化之后:

image

步长不宜过大或过小

image

梯度下降法是按下面的流程进行的:(转自:http://blog.sina.com.cn/s/blog_62339a2401015jyq.html)

1)首先对θ赋值,这个值可以是随机的,也可以让θ是一个全零的向量。

2)改变θ的值,使得J(θ)按梯度下降的方向进行减少。

为了方便大家的理解,首先给出单变量的例子:

eg:求image的最小值。(注:image

image

java代码如下:

·

package OneVariable;public class OneVariable{public static void main(String[] args){double e=0.00001;//定义迭代精度double alpha=0.5;//定义迭代步长double x=0;            //初始化xdouble y0=2*x*x+3*x+1;//与初始化x对应的y值double y1=0;//定义变量,用于保存当前值while (true){x=x-alpha*(4.0*x+3.0);y1=2*x*x+3*x+1;if (Math.abs(y1-y0)<e)//如果2次迭代的结果变化很小,结束迭代
        {break;}y0=y1;//更新迭代的结果
    }System.out.println("Min(f(x))="+y0);System.out.println("minx="+x);}
}//输出
Min(f(x))=1.0
minx=-1.5

两个变量的时候,为了更清楚,给出下面的图:

image

这是一个表示参数θ与误差函数J(θ)的关系图,红色的部分是表示J(θ)有着比较高的取值,我们需要的是,能够让J(θ)的值尽量的低。也就是深蓝色的部分。θ0,θ1表示θ向量的两个维度。

在上面提到梯度下降法的第一步是给θ给一个初值,假设随机给的初值是在图上的十字点。

然后我们将θ按照梯度下降的方向进行调整,就会使得J(θ)往更低的方向进行变化,如图所示,算法的结束将是在θ下降到无法继续下降为止。

image

当然,可能梯度下降的最终点并非是全局最小点,可能是一个局部最小点,可能是下面的情况:

image

上面这张图就是描述的一个局部最小点,这是我们重新选择了一个初始点得到的,看来我们这个算法将会在很大的程度上被初始点的选择影响而陷入局部最小点

一个很重要的地方值得注意的是,梯度是有方向的,对于一个向量θ,每一维分量θi都可以求出一个梯度的方向,我们就可以找到一个整体的方向,在变化的时候,我们就朝着下降最多的方向进行变化就可以达到一个最小点,不管它是局部的还是全局的。


理论的知识就讲到这,下面,我们就用java去实现这个算法:

梯度下降有两种:批量梯度下降和随机梯度下降。详见:http://blog.csdn.net/lilyth_lilyth/article/details/8973972

测试数据就用课后题中的数据(ex1data1.txt),用matlab打开作图得到:

image

首先说明:以下源码是不正确的,具体为什么不正确我还没搞清楚!非常希望各位高手能够指正!

测试数据及源码下载:http://pan.baidu.com/s/1mgiIVm4

OneVariable.java
 1 package OneVariableVersion;
 2 
 3 import java.io.IOException;
 4 import java.util.List;
 5 
 6 
 7 /**
 8  * Linear Regression with One Variable
 9  * @author XBW
10  * @date 2014年8月17日
11  */
12 
13 public class OneVariable{
14     public static final Double e=0.00001;
15     public static List<Data> DS;
16     public static Double step;
17     public static Double m;
18     
19     /**
20      * 计算当前参数是否符合
21      * @param ans
22      * @param datalist
23      * @return
24      */
25     public static Ans calc(Ans ans){
26         Double costfun;
27         do{
28             costfun=calcAccuracy(ans);
29             ans=update(ans);
30             step*=0.3;
31         }while(Math.abs(costfun-calcAccuracy(ans))>e);
32         ans.ifConvergence=true;
33         return ans;
34     }
35     
36     /**
37      * 判断当前ans是否满足精度,y=t0+t1*x
38      * @param ans
39      * @return
40      */
41     public static Double calcAccuracy(Ans ans){
42         Double cost=0.0;
43         Double tmp;
44         for(int i=0;i<m;i++){
45             tmp=DS.get(i).y-(ans.theta[0]*DS.get(i).x[0]+ans.theta[1]*DS.get(i).x[1]);
46             cost+=tmp*tmp;
47         }
48         cost/=(2*m);
49         return cost;
50     }
51     
52     /**
53      * 更新ans
54      * @param ans,学习速率为step,m为数据量
55      * @return
56      */
57     public static Ans update(Ans ans){
58         Double[] tmp=new Double[100] ;
59         for(int i=0;i<2;i++){
60             tmp[i]=ans.theta[i]-step*fun(ans,i);
61         }
62         for(int i=0;i<2;i++){
63             ans.theta[i]=tmp[i];
64         }
65         return ans;
66     }
67     
68     /**
69      * 计算偏导
70      * @return
71      */
72     public static Double fun(Ans ans,int xi){
73         Double ret = 0.0;
74         for(int i=0;i<m;i++){
75             ret+=(ans.theta[0]*DS.get(i).x[0]+ans.theta[1]*DS.get(i).x[1]-DS.get(i).y)*DS.get(i).x[xi];
76         }
77         ret/=m;
78         return ret;        
79     }
80     
81     public static void main(String[] args) throws IOException{
82         DS=new DataSet().ds;
83         step=1.0;        
84         m=(double)DS.size();
85         
86         
87         Double[] theta={0.0,0.0};                     //初始设定theta0=0,theta1=0
88         Ans ans=new Ans(theta,false);
89         Ans answer;
90         answer=calc(ans);
91         System.out.println("theta1= "+answer.theta[0]+"      theta2="+answer.theta[1]);
92     }
93 }

DataSet.java

 1 package OneVariableVersion;
 2 
 3 import java.io.BufferedReader;
 4 import java.io.File;
 5 import java.io.FileReader;
 6 import java.io.IOException;
 7 import java.util.ArrayList;
 8 import java.util.List;
 9 
10 
11 /**
12  * 数据处理
13  * @author XBW
14  * @date 2014年8月17日
15  */
16 
17 public class DataSet{
18     String defaultpath="D:\\MachineLearning\\StanfordbyAndrewNg\\II.LinearRegressionwithOneVariable(Week1)\\homework\\ex1data1.txt";
19     
20     List<Data> ds=new ArrayList<Data>();
21     
22     public DataSet() throws IOException{
23         File dataset=new File(defaultpath);
24         BufferedReader br = new BufferedReader(new FileReader(dataset));
25         String tsing;
26         while((tsing=br.readLine())!=null){
27             String[] dlist=tsing.split(",");
28             Data dtmp=new Data(Double.parseDouble(dlist[0]),Double.parseDouble(dlist[1]));
29             this.ds.add(dtmp);
30         }
31         br.close();
32     }
33 }

Ans.java

 1 package OneVariableVersion;
 2 
 3 /**
 4  * 保存结果,y=t0+t1*x
 5  * @author XBW
 6  * @date 2014年8月17日
 7  */
 8 
 9 public class Ans {
10     Double[] theta;
11     boolean ifConvergence;
12     
13     public Ans(Double[] tmp,boolean ifCon){
14         this.theta=tmp;
15         this.ifConvergence=ifCon;
16     }
17 }

Data.java

 1 package OneVariableVersion;
 2 
 3 
 4 /**
 5  * 一条数据
 6  * @author XBW
 7  * @date 2014年8月17日
 8  */
 9 public class Data {
10     Double[] x=new Double[2];
11     Double y;
12     
13     public Data(Double xtmp,Double ytmp){
14         this.x[0]=1.0;
15         this.x[1]=xtmp;
16         this.y=ytmp;
17     }
18 }

总结:写代码的时候有几个讲究:

  1. 步长是否需要动态变化,按照coursera公开课上讲的是不必要动态改变的,因为偏导数会越来越小,但在实际情况下,按照一定的比值缩小或者自己定义一种缩小的方式可能是有必要的,所以具体情况具体分析;
  2. 初始步长的设定也是很重要的,大了就不会得到结果,因为发散了;步长越大,下降速率越快,但是也会导致震荡,所以,还是哪句话:具体问题具体分析;

转载于:https://www.cnblogs.com/XBWer/p/3912792.html

相关文章:

101种设计模式

https://sourcemaking.com/design-patterns-and-tips

(C++)1032 挖掘机技术哪家强

笔记&#xff1a;考虑到输入只有一所学校&#xff0c;且得分还为0的特殊情况&#xff0c;应该把high初始化为1 #include<cstdio> #include<cmath> #include<cstring> #include<algorithm> using namespace std;int grds[100010] {0};int main(){int …

数据库打开报错: 值不能为空

报错信息如下&#xff1a; 数据库客户端打不开 解决方案&#xff1a; 找到下面的目录C:\Users\<username>\AppData\Local\Temp 创建一个空文件夹 名称是&#xff1a; 2 重新打开数据库转载于:https://www.cnblogs.com/Mander/p/3921251.html

学习 JavaScript (四)核心概念:操作符

JavaScript 的核心概念主要由语法、变量、数据类型、操作符、语句、函数组成&#xff0c;前面三个上一篇文章已经讲解完了。后面三个内容超级多&#xff0c;这篇文章主要讲解的是操作符。 操作符 什么叫做操作符&#xff1f; 这是一种工具&#xff0c;帮助我们操作字符串、数字…

(C++)1011 World Cup Betting

笔记&#xff1a;我觉得这一次的代码很优雅 #include<cstdio> #include<cmath> #include<cstring> #include<algorithm> using namespace std;int maxPro(double a[3]){//返回值最大的下标 int idx0,max_pro0;for(int i0;i<3;i){if(a[i]>max_pr…

Ext学习-前后交互模式介绍

在前后台交互模式的介绍中&#xff0c;实际上就是Store中Proxy相关的内容&#xff0c;比如Ajax提交。 所以详细的文档请参考&#xff1a; Ext学习-基础概念&#xff0c;核心思想介绍 中关于数据模型和MVC结构部分。 作者&#xff1a;sdjnzqr 出处&#xff1a;http://www.cnblog…

让你彻底明白什么叫游戏引擎(1)

在阅读各种游戏介绍的时候我们常常会碰见“引擎”&#xff08;Engine&#xff09;这个单词&#xff0c;引擎在游戏中究竟起着什么样的作用&#xff1f;它的进化对于游戏的发展产生了哪些影响&#xff1f;希望下面这篇文章能为大家释疑。以希望能够帮助一些刚进入游戏行业的小菜…

185.dubbo 后台管理系统

2019独角兽企业重金招聘Python工程师标准>>> 1. 效果及目的 效果&#xff1a; 目的&#xff1a;查看 管理服务 2. 启动要求 &#xff08;1&#xff09;项目是dubbo &#xff08;2&#xff09;jdk 1.7 (3) dubbo的war要与zookeeper在同一台服务上 3. 安装zookeeper 要…

(C++)1027 打印沙漏

笔记&#xff1a;星号右边的空格不用打印 #include<cstdio> #include<cmath> #include<cstring> #include<algorithm> using namespace std;int main(){int n;char c;scanf("%d %c",&n,&c);int clock[23];int col;for(int i1;i<…

黑帽大会2014:10个酷炫的黑客工具

http://www.csdn.net/article/2014-08-21/2821304 用于恶意软件分析的Maltrieve 安全研究人员使用Maltrieve工具收集服务器上的恶意软件。通过这个开源工具&#xff0c;恶意软件分析人员可以通过分析URL链表和已知的托管地址获得最新鲜的样本。 Kyle Maxwell是VeriSign的一名威…

C#无符号右移

/// <summary>/// 无符号右移&#xff0c;与JS中的>>>等价/// </summary>/// <param name"x">要移位的数</param>/// <param name"y">移位数</param>/// <returns></returns>public static int …

1027 Colors in Mars

笔记&#xff1a;本题属于进制转换&#xff0c;但是考察的重点不在除基取余上&#xff0c;因为转化得到的数只有两位&#xff0c;很容易得到每位上面应该是什么&#xff0c;但是和其他题不同的地方在于&#xff0c;每位可填的不见得是0~9&#xff0c;还包括ABC&#xff0c;这就…

json对象和json字符串转换方法

在WEB数据传输过程中&#xff0c;json是以文本&#xff0c;即字符串的轻量级形式传递的&#xff0c;而客户端一般用JS操作的是接收到的JSON对象&#xff0c;所以&#xff0c;JSON对象和JSON字符串之间的相互转换、JSON数据的解析是关键。 先明确2个概念例如&#xff1a; JSON字…

python-docx操作

import docx# 读取docx文档内容def readWord():doc docx.Document(demo.docx)fullText []for para in doc.paragraphs:fullText.append( para.text)print(\n . join(fullText))readWord()官方API&#xff1a;https://python-docx.readthedocs.io/en/latest/index.html ;转载…

javascript中FORM表单的submit()方法经验教训.

author songfeng 因为JS内对象的方法实际上是存储语句的一个类似于指针的东西. 其指向了内存的一个位置, 也就是其函数的位置,当然也可以让其指向一个变量值. var foo new Object();foo.bar function() {} //现在foo.bar就是指向了这个函数的内存位置.foo.bar &q…

1058 A+B in Hogwarts

笔记&#xff1a;和乙级的在霍格沃兹找零钱不同&#xff0c;这里不需要判断给出的两个数的大小&#xff0c;也没必要先都换算成最小的单位&#xff0c;可以直接从最低位开始加&#xff0c;如果超过该位的范围&#xff0c;则向上一位进一即可。 #include<cstdio> #includ…

DDD领域驱动设计之聚合、实体、值对象

关于具体需求&#xff0c;请看前面的博文&#xff1a;DDD领域驱动设计实践篇之如何提取模型&#xff0c;下面是具体的实体、聚合、值对象的代码&#xff0c;不想多说什么是实体、聚合等概念&#xff0c;相信理论的东西大家已经知晓了。本人对DDD表示好奇&#xff0c;没有在真正…

C#用 SendKyes 结合 Process 或 API FindWindow、SendMessage(PostMessage) 等控制外部程序

Win32 平台是 消息驱动模式.Net 框架是 事件驱动模式标题所指的 “控制外部程序”&#xff0c;外部程序是指与本程序无内在相关性的另外一个程序 基于上面提到的&#xff0c;对于.NET的winform程序&#xff0c;在默认情况下&#xff08;即未对接收消息的事件做自定义处理&#…

springMVC swagger2

参考地址&#xff1a;https://www.cnblogs.com/exmyth/p/7183753.html https://blog.csdn.net/programmer_sean/article/details/72236948 1. maven 依赖 <dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger2</artifactId&…

1061 Dating

笔记&#xff1a; 第一个输出根据的是大写字母 第二个输出根据的是0-9andA-N 第三个输出根据的是大写字母和小写字母 知道范围便方便确定边界 两两比对时&#xff0c;先遍历一个字符串&#xff0c;遇到在范围内的字符&#xff0c;看其和第二个字符串同位置的字符是否相等 …

PA 项目创建任务

---- 创建任务 DECLAREp_project_id NUMBER : 155233;p_task_number VARCHAR2(240) : CXYTEST0001;p_task_name VARCHAR2(240) : 接口测试CXYTEST0001;p_task_description VARCHAR2(240) : TASKCXYTEST0001;p_scheduled_start_date DAT…

SSM登陆拦截器实现

首先在springmvc中配置拦截器 <!-- 配置拦截器 --><mvc:interceptors><mvc:interceptor><!-- 拦截所有mvc控制器 --><mvc:mapping path"/**"/><!-- mvc:exclude-mapping是另外一种拦截&#xff0c;它可以在你后来的测试中对某个页面…

AGG 学习笔记

我了解的&#xff21;&#xff27;&#xff27;的总体结构按照文件大致分为&#xff1a;   &#xff11;&#xff09;基本定义&#xff08;config,basics....)&#xff1b;   &#xff12;&#xff09;基本操作、类型&#xff08;主要供&#xff21;&#xff27;&#xff2…

1073 Scientific Notation

笔记&#xff1a;这是我迄今为止写过的最复杂的字符串处理算法题。 收获&#xff1a;分而治之&#xff0c;想不清楚就自己设计测试用例和结果。列举然后归类。 以下是程序流程图 #include<cstdio> #include<cmath> #include<cstring> #include<algorith…

几个笔试题目总结

1、阿里某个笔试题&#xff0c;两个字符串text&#xff0c;query&#xff0c;找到text中包含的最长的query的字串&#xff1a; public static String subStr(String text, String query) {if (text ! null && query ! null) {int length query.length();for (int i 0…

baidu mp3竟然还加密,太扯了

baidu mp3竟然还加密&#xff0c;太扯了 public class BaiduHelper { static int F 0; static string I "", J ""; static string O ""; static string E ""; static int[] K new int[1000…

Ubuntu 之linux与windows互传文件

Windows系统下与linux传输文件 windows环境下&#xff0c;windows传出数据到linux下 确保ubuntu安装了ssh服务端。如果没有安装&#xff0c;使用以下命令安装&#xff1a; sudo aptget install ssh service sshd restart 2.windows下下载pscp.exe软件从PuTTY官方网站下载pscp.e…

1048 数字加密 --非满分

16/20 非满分&#xff0c;待来日复习双指针再分析原因 #include<cstdio> #include<cmath> #include<cstring> #include<algorithm> #include<bits/stdc.h> using namespace std;void reverStr(char str[]){int len strlen(str);for(int i0;i&l…

端到端对话模型新突破!Facebook发布大规模个性化对话数据库

作者&#xff5c;Pierre-Emmanuel Mazare 等译者&#xff5c;郝毅编辑&#xff5c;Debra出处丨 AI 前线AI 前线导读&#xff1a;聊天机器人是目前非常流行的一种人工智能系统。目前大部分聊天机器人的衔接性都不是很好&#xff0c;尤其是在没有主动的重调优策略下训练出的端到端…

上传文件大小的配置Webcong

修改Webcong文件:<system.web><httpRuntime maxRequestLength"40690" useFullyQualifiedRedirectUrl"true" executionTimeout"6000" useFullyQualifiedRedirectUrl"false" minFreeThreads"…