前言:Spring框架对于Java后端程序员来说再熟悉不过了,以前只知道它用的反射实现的,但了解之后才知道有很多巧妙的设计在里面。如果不看Spring的源码,你将会失去一次和大师学习的机会:它的代码规范,设计思想很值得学习。我们程序员大部分人都是野路子,不懂什么叫代码规范。写了一个月的代码,最后还得其他老司机花3天时间重构,相信大部分老司机都很头疼看新手的代码。
先来一份SpringMVC的实现功能,然后通过代码来讲解我们手写SpringMVC需要写哪些东西
大多数的公司对代码的结构都是:
浏览器请求 ==》Controller层—》Service层—》Dao层 ==》数据库
数据请求到后,就会响应返回给浏览器展示
废话不多说,我们进入今天的正题,在Web应用程序设计中,MVC模式已经被广泛使用。SpringMVC以DispatcherServlet为核心,负责协调和组织不同组件以完成请求处理并返回响应的工作,实现了MVC模式。想要实现自己的SpringMVC框架,需要从以下几点入手:
一、了解SpringMVC运行流程及九大组件
二、梳理自己的SpringMVC的设计思路
三、实现自己的SpringMVC框架
一、了解SpringMVC运行流程及九大组件
1、SpringMVC的运行流程
分析一下
⑴ 用户发送请求至前端控制器DispatcherServlet
⑵ DispatcherServlet收到请求调用HandlerMapping处理器映射器。
⑶ 处理器映射器根据请求url找到具体的处理器,生成处理器对象及处理器拦截器(如果有则生成)一并返回给DispatcherServlet。
⑷ DispatcherServlet通过HandlerAdapter处理器适配器调用处理器
⑸ 执行处理器(Controller,也叫后端控制器)。
⑹ Controller执行完成返回ModelAndView
⑺ HandlerAdapter将controller执行结果ModelAndView返回给DispatcherServlet
⑻ DispatcherServlet将ModelAndView传给ViewReslover视图解析器
⑼ ViewReslover解析后返回具体View
⑽ DispatcherServlet对View进行渲染视图(即将模型数据填充至视图中)。
⑾ DispatcherServlet响应用户。
从上面可以看出,DispatcherServlet有接收请求,响应结果,转发等作用。有了DispatcherServlet之后,可以减少组件之间的耦合度。
2、SpringMVC的九大组件(ref:【SpringMVC】9大组件概览)
protected void initStrategies(ApplicationContext context) {
//用于处理上传请求。处理方法是将普通的request包装成MultipartHttpServletRequest,后者可以直接调用getFile方法获取File.
initMultipartResolver(context);
//SpringMVC主要有两个地方用到了Locale:一是ViewResolver视图解析的时候;二是用到国际化资源或者主题的时候。
initLocaleResolver(context);
//用于解析主题。SpringMVC中一个主题对应一个properties文件,里面存放着跟当前主题相关的所有资源、
//如图片、css样式等。SpringMVC的主题也支持国际化,
initThemeResolver(context);
//用来查找Handler的。
initHandlerMappings(context);
//从名字上看,它就是一个适配器。Servlet需要的处理方法的结构却是固定的,都是以request和response为参数的方法。
//如何让固定的Servlet处理方法调用灵活的Handler来进行处理呢?这就是HandlerAdapter要做的事情
initHandlerAdapters(context);
//其它组件都是用来干活的。在干活的过程中难免会出现问题,出问题后怎么办呢?
//这就需要有一个专门的角色对异常情况进行处理,在SpringMVC中就是HandlerExceptionResolver。
initHandlerExceptionResolvers(context);
//有的Handler处理完后并没有设置View也没有设置ViewName,这时就需要从request获取ViewName了,
//如何从request中获取ViewName就是RequestToViewNameTranslator要做的事情了。
initRequestToViewNameTranslator(context);
//ViewResolver用来将String类型的视图名和Locale解析为View类型的视图。
//View是用来渲染页面的,也就是将程序返回的参数填入模板里,生成html(也可能是其它类型)文件。
initViewResolvers(context);
//用来管理FlashMap的,FlashMap主要用在redirect重定向中传递参数。
initFlashMapManager(context);
}
二、梳理SpringMVC的设计思路
本文只实现自己的@Controller、@RequestMapping、@RequestParam注解起作用,其余SpringMVC功能读者可以尝试自己实现。
1、读取配置
从图中可以看出,SpringMVC本质上是一个Servlet,这个 Servlet 继承自 HttpServlet。FrameworkServlet负责初始化SpringMVC的容器,并将Spring容器设置为父容器。因为本文只是实现SpringMVC,对于Spring容器不做过多讲解
为了读取web.xml中的配置,我们用到ServletConfig这个类,它代表当前Servlet在web.xml中的配置信息。通过web.xml中加载我们自己写的MyDispatcherServlet和读取配置文件。
2、初始化阶段
在前面我们提到DispatcherServlet的initStrategies方法会初始化9大组件,但是这里将实现一些SpringMVC的最基本的组件而不是全部,按顺序包括:
1、加载配置文件
2、扫描用户配置包下面所有的类
3、拿到扫描到的类,通过反射机制,实例化。并且放到ioc容器中(Map的键值对 beanName-bean) beanName默认是首字母小写
4、初始化HandlerMapping,这里其实就是把url和method对应起来放在一个k-v的Map中,在运行阶段取出
3、运行阶段
每一次请求将会调用doGet或doPost方法,所以统一运行阶段都放在doDispatch方法里处理,它会根据url请求去HandlerMapping中匹配到对应的Method,然后利用反射机制调用Controller中的url对应的方法,并得到结果返回。按顺序包括以下功能:
异常的拦截
获取请求传入的参数并处理参数
通过初始化好的handlerMapping中拿出url对应的方法名,反射调用
三、实现自己的SpringMVC框架
工程文件及目录
1、首先,新建一个maven项目,在pom.xml中导入以下依赖:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<groupId>com.qw.boot.springmvc</groupId>
<artifactId>springmvc</artifactId>
<version>1.0.0-SNAPSHOT</version>
<packaging>war</packaging>
<modelVersion>4.0.0</modelVersion>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<java.version>1.8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>javax.servlet-api</artifactId>
<version>3.0.1</version>
<scope>provided</scope>
</dependency>
<!--Lombok-->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
</dependencies>
</project>
2、接着在WEB-INF下创建一个web.xml,如下配置:
<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://java.sun.com/xml/ns/javaee" xmlns:web="http://java.sun.com/xml/ns/javaee/web-app_2_5.xsd"
xsi:schemaLocation="http://java.sun.com/xml/ns/javaee http://java.sun.com/xml/ns/javaee/web-app_3_0.xsd"
version="3.0">
<servlet>
<servlet-name>springmvc</servlet-name>
<servlet-class>com.qw.boot.springmvc.servlet.XXDispatcherServlet</servlet-class>
<init-param>
<param-name>contextConfigLocation</param-name>
<param-value>config.properties</param-value>
</init-param>
<load-on-startup>1</load-on-startup>
</servlet>
<servlet-mapping>
<servlet-name>springmvc</servlet-name>
<url-pattern>/*</url-pattern>
</servlet-mapping>
</web-app>
3、application.properties文件中只是配置要扫描的包到SpringMVC容器中。
scanPackage=com.qw.boot.springmvc.web
4、创建元注解
XXAutowired
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXAutowired {
String value() default "";
}
XXController
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXController {
String value() default "";
}
XXQualifier
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXQualifier {
String value() default "";
}
XXRequestMapping
@Target({ElementType.TYPE,ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXRequestMapping {
/**
* 表示访问该方法的url
* @return
*/
String value() default "";
}
XXRequestParam
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXRequestParam {
/**
* 表示参数的别名,必填
* @return
*/
String value();
}
XXService
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXService {
String value() default "";
}
5、逻辑代码
UserController.java
@XXController
@XXRequestMapping("/user")
public class UserController {
@XXAutowired
private UserService userService;
@XXRequestMapping("/index")
public String index(HttpServletRequest request, HttpServletResponse response,
@XXRequestParam("name")String name) throws IOException {
String res = userService.get(name);
System.out.println(name+"=>"+res);
response.setContentType("application/json;charset=UTF-8");
response.getWriter().write(res);
return "index";
}
@XXRequestMapping("/list")
public String list(HttpServletRequest request,HttpServletResponse response)
throws IOException{
List<User> users = userService.list();
response.setContentType("application/json;charset=UTF-8");
response.getWriter().write(users.toString());
return "list";
}
}
UserService.java
public interface UserService {
String get(String name);
List<User> list();
}
UserServiceImpl.java
@XXService("userService")
public class UserServiceImpl implements UserService {
private static Map<String, User> users = new HashMap<String, User>();
static{
users.put("aa", new User("1","aaa","123456"));
users.put("bb", new User("2","bbb","123456"));
users.put("cc", new User("3","ccc","123456"));
users.put("dd", new User("4","ddd","123456"));
users.put("ee", new User("5","eee","123456"));
}
@Override
public String get(String name) {
User user = users.get(name);
if(user==null){
user = users.get("aa");
}
return user.toString();
}
@Override
public List<User> list() {
List<User> list = new ArrayList<User>();
for(Map.Entry<String, User> entry : users.entrySet()){
list.add(entry.getValue());
}
return list;
}
}
6、核心代码
XXDispatcherServlet.java
public class XXDispatcherServlet extends HttpServlet {
private Properties contextConfig = new Properties();
private List<String> classNames = new ArrayList<String>();
private Map<String, Object> ioc = new HashMap<String, Object>();
private List<Handler> handlerMapping = new ArrayList<Handler>();
private static final long serialVersionUID = -4943120355864715254L;
@Override
public void init(ServletConfig config) throws ServletException {
//load config
doLoadConfig(config.getInitParameter("contextConfigLocation"));
//scan relative class
doScanner(contextConfig.getProperty("scanPackage"));
//init ioc container put relative class to it
doInstance();
//inject dependence
doAutoWired();
//init handlerMapping
initHandlerMapping();
}
private void initHandlerMapping() {
if (ioc.isEmpty()) {
return;
}
for (Map.Entry<String, Object> entry : ioc.entrySet()) {
Class<?> clazz = entry.getValue().getClass();
if (!clazz.isAnnotationPresent(XXController.class)) {
continue;
}
String baseUrl = "";
if (clazz.isAnnotationPresent(XXRequestMapping.class)) {
XXRequestMapping requestMapping = clazz.getAnnotation(XXRequestMapping.class);
baseUrl = requestMapping.value();
}
Method[] methods = clazz.getMethods();
for (Method method : methods) {
if (!method.isAnnotationPresent(XXRequestMapping.class)) {
continue;
}
XXRequestMapping requestMapping = method.getAnnotation(XXRequestMapping.class);
String url = (baseUrl + requestMapping.value()).replaceAll("/+", "/");
Pattern pattern = Pattern.compile(url);
handlerMapping.add(new Handler(pattern, entry.getValue(), method));
System.out.println("mapped:" + url + "=>" + method);
}
}
}
private void doAutoWired() {
if (ioc.isEmpty()) {
return;
}
for (Map.Entry<String, Object> entry : ioc.entrySet()) {
//依赖注入->给加了XXAutowired注解的字段赋值
Field[] fields = entry.getValue().getClass().getDeclaredFields();
for (Field field : fields) {
if (!field.isAnnotationPresent(XXAutowired.class)) {
continue;
}
XXAutowired autowired = field.getAnnotation(XXAutowired.class);
String beanName = autowired.value();
if ("".equals(beanName)) {
beanName = field.getType().getName();
}
field.setAccessible(true);
try {
field.set(entry.getValue(), ioc.get(beanName));
} catch (IllegalAccessException e) {
e.printStackTrace();
continue;
}
}
}
}
private void doInstance() {
if (classNames.isEmpty()) {
return;
}
try {
for (String className : classNames) {
Class<?> clazz = Class.forName(className);
if (clazz.isAnnotationPresent(XXController.class)) {
String beanName = lowerFirstCase(clazz.getSimpleName());
ioc.put(beanName, clazz.newInstance());
} else if (clazz.isAnnotationPresent(XXService.class)) {
XXService service = clazz.getAnnotation(XXService.class);
String beanName = service.value();
if ("".equals(beanName)) {
beanName = lowerFirstCase(clazz.getSimpleName());
}
Object instance = clazz.newInstance();
ioc.put(beanName, instance);
Class<?>[] interfaces = clazz.getInterfaces();
for (Class<?> i : interfaces) {
ioc.put(i.getName(), instance);
}
} else {
continue;
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
private void doScanner(String packageName) {
URL resource =
this.getClass().getClassLoader().getResource("/" + packageName.replaceAll("\\.", "/"));
File classDir = new File(resource.getFile());
for (File classFile : classDir.listFiles()) {
if (classFile.isDirectory()) {
doScanner(packageName + "." + classFile.getName());
} else {
String className = (packageName + "." + classFile.getName()).replace(".class", "");
classNames.add(className);
}
}
}
private void doLoadConfig(String location) {
InputStream input = this.getClass().getClassLoader().getResourceAsStream(location);
try {
contextConfig.load(input);
} catch (IOException e) {
e.printStackTrace();
} finally {
if (input != null) {
try {
input.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse res)
throws ServletException, IOException {
this.doPost(req, res);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse res)
throws ServletException, IOException {
doDispatcher(req, res);
}
public void doDispatcher(HttpServletRequest req, HttpServletResponse res) {
try {
Handler handler = getHandler(req);
if (handler == null) {
res.getWriter().write("404 not found.");
return;
}
Class<?>[] paramTypes = handler.method.getParameterTypes();
Object[] paramValues = new Object[paramTypes.length];
Map<String, String[]> params = req.getParameterMap();
for (Entry<String, String[]> param : params.entrySet()) {
String value = Arrays.toString(param.getValue()).replaceAll("\\[|\\]", "");
if (!handler.paramIndexMapping.containsKey(param.getKey())) {
continue;
}
int index = handler.paramIndexMapping.get(param.getKey());
paramValues[index] = convert(paramTypes[index], value);
}
int reqIndex = handler.paramIndexMapping.get(HttpServletRequest.class.getName());
paramValues[reqIndex] = req;
int resIndex = handler.paramIndexMapping.get(HttpServletResponse.class.getName());
paramValues[resIndex] = res;
handler.method.invoke(handler.controller, paramValues);
} catch (Exception e) {
e.printStackTrace();
}
String url = req.getRequestURI();
String contextPath = req.getContextPath();
url = url.replace(contextPath, "").replaceAll("/+", "/");
}
private Object convert(Class<?> type, String value) {
if (Integer.class == type) {
return Integer.valueOf(value);
}
return value;
}
private String lowerFirstCase(String str) {
char[] chars = str.toCharArray();
chars[0] += 32;
return String.valueOf(chars);
}
private Handler getHandler(HttpServletRequest req) {
if (handlerMapping.isEmpty()) {
return null;
}
String url = req.getRequestURI();
String contextPath = req.getContextPath();
url = url.replace(contextPath, "").replaceAll("/+", "/");
for (Handler handler : handlerMapping) {
Matcher matcher = handler.pattern.matcher(url);
if (!matcher.matches()) {
continue;
}
return handler;
}
return null;
}
private class Handler {
protected Object controller;
protected Method method;
protected Pattern pattern;
protected Map<String, Integer> paramIndexMapping;
protected Handler(Pattern pattern, Object controller, Method method) {
this.pattern = pattern;
this.controller = controller;
this.method = method;
paramIndexMapping = new HashMap<String, Integer>();
putParamIndexMapping(method);
}
private void putParamIndexMapping(Method method) {
Annotation[][] pa = method.getParameterAnnotations();
for (int i = 0; i < pa.length; i++) {
for (Annotation a : pa[i]) {
if (a instanceof XXRequestParam) {
String paramName = ((XXRequestParam) a).value();
if (!"".equals(paramName)) {
paramIndexMapping.put(paramName, i);
}
}
}
}
Class<?>[] paramTypes = method.getParameterTypes();
for (int i = 0; i < paramTypes.length; i++) {
Class<?> type = paramTypes[i];
if (type == HttpServletRequest.class || type == HttpServletResponse.class) {
paramIndexMapping.put(type.getName(), i);
}
}
}
}
}
测试:
1、http://localhost:8080/springmvc/user/index?name=bb
2、http://localhost:8080/springmvc/user/list
3、http://localhost:8080/springmvc/user/detail 出现404