前言:Spring框架对于Java后端程序员来说再熟悉不过了,以前只知道它用的反射实现的,但了解之后才知道有很多巧妙的设计在里面。如果不看Spring的源码,你将会失去一次和大师学习的机会:它的代码规范,设计思想很值得学习。我们程序员大部分人都是野路子,不懂什么叫代码规范。写了一个月的代码,最后还得其他老司机花3天时间重构,相信大部分老司机都很头疼看新手的代码。

先来一份SpringMVC的实现功能,然后通过代码来讲解我们手写SpringMVC需要写哪些东西
大多数的公司对代码的结构都是:
浏览器请求 ==》Controller层—》Service层—》Dao层 ==》数据库

数据请求到后,就会响应返回给浏览器展示

废话不多说,我们进入今天的正题,在Web应用程序设计中,MVC模式已经被广泛使用。SpringMVC以DispatcherServlet为核心,负责协调和组织不同组件以完成请求处理并返回响应的工作,实现了MVC模式。想要实现自己的SpringMVC框架,需要从以下几点入手:

   一、了解SpringMVC运行流程及九大组件

   二、梳理自己的SpringMVC的设计思路

   三、实现自己的SpringMVC框架

一、了解SpringMVC运行流程及九大组件

1、SpringMVC的运行流程

image.png
分析一下

   ⑴ 用户发送请求至前端控制器DispatcherServlet

   ⑵ DispatcherServlet收到请求调用HandlerMapping处理器映射器。

   ⑶ 处理器映射器根据请求url找到具体的处理器,生成处理器对象及处理器拦截器(如果有则生成)一并返回给DispatcherServlet。

   ⑷ DispatcherServlet通过HandlerAdapter处理器适配器调用处理器

   ⑸ 执行处理器(Controller,也叫后端控制器)。

   ⑹ Controller执行完成返回ModelAndView

   ⑺ HandlerAdapter将controller执行结果ModelAndView返回给DispatcherServlet

   ⑻ DispatcherServlet将ModelAndView传给ViewReslover视图解析器

   ⑼ ViewReslover解析后返回具体View

   ⑽ DispatcherServlet对View进行渲染视图(即将模型数据填充至视图中)。

   ⑾ DispatcherServlet响应用户。

从上面可以看出,DispatcherServlet有接收请求,响应结果,转发等作用。有了DispatcherServlet之后,可以减少组件之间的耦合度。

2、SpringMVC的九大组件(ref:【SpringMVC】9大组件概览)

protected void initStrategies(ApplicationContext context) {
    //用于处理上传请求。处理方法是将普通的request包装成MultipartHttpServletRequest,后者可以直接调用getFile方法获取File.
    initMultipartResolver(context);
    //SpringMVC主要有两个地方用到了Locale:一是ViewResolver视图解析的时候;二是用到国际化资源或者主题的时候。
    initLocaleResolver(context); 
    //用于解析主题。SpringMVC中一个主题对应一个properties文件,里面存放着跟当前主题相关的所有资源、
    //如图片、css样式等。SpringMVC的主题也支持国际化, 
    initThemeResolver(context);
    //用来查找Handler的。
    initHandlerMappings(context);
    //从名字上看,它就是一个适配器。Servlet需要的处理方法的结构却是固定的,都是以request和response为参数的方法。
    //如何让固定的Servlet处理方法调用灵活的Handler来进行处理呢?这就是HandlerAdapter要做的事情
    initHandlerAdapters(context);
    //其它组件都是用来干活的。在干活的过程中难免会出现问题,出问题后怎么办呢?
    //这就需要有一个专门的角色对异常情况进行处理,在SpringMVC中就是HandlerExceptionResolver。
    initHandlerExceptionResolvers(context);
    //有的Handler处理完后并没有设置View也没有设置ViewName,这时就需要从request获取ViewName了,
    //如何从request中获取ViewName就是RequestToViewNameTranslator要做的事情了。
    initRequestToViewNameTranslator(context);
    //ViewResolver用来将String类型的视图名和Locale解析为View类型的视图。
    //View是用来渲染页面的,也就是将程序返回的参数填入模板里,生成html(也可能是其它类型)文件。
    initViewResolvers(context);
    //用来管理FlashMap的,FlashMap主要用在redirect重定向中传递参数。
    initFlashMapManager(context); 
}

二、梳理SpringMVC的设计思路

本文只实现自己的@Controller、@RequestMapping、@RequestParam注解起作用,其余SpringMVC功能读者可以尝试自己实现。

1、读取配置

image.png
从图中可以看出,SpringMVC本质上是一个Servlet,这个 Servlet 继承自 HttpServlet。FrameworkServlet负责初始化SpringMVC的容器,并将Spring容器设置为父容器。因为本文只是实现SpringMVC,对于Spring容器不做过多讲解
为了读取web.xml中的配置,我们用到ServletConfig这个类,它代表当前Servlet在web.xml中的配置信息。通过web.xml中加载我们自己写的MyDispatcherServlet和读取配置文件。

2、初始化阶段

在前面我们提到DispatcherServlet的initStrategies方法会初始化9大组件,但是这里将实现一些SpringMVC的最基本的组件而不是全部,按顺序包括:

1、加载配置文件
2、扫描用户配置包下面所有的类
3、拿到扫描到的类,通过反射机制,实例化。并且放到ioc容器中(Map的键值对 beanName-bean) beanName默认是首字母小写
4、初始化HandlerMapping,这里其实就是把url和method对应起来放在一个k-v的Map中,在运行阶段取出

3、运行阶段

每一次请求将会调用doGet或doPost方法,所以统一运行阶段都放在doDispatch方法里处理,它会根据url请求去HandlerMapping中匹配到对应的Method,然后利用反射机制调用Controller中的url对应的方法,并得到结果返回。按顺序包括以下功能:

异常的拦截
获取请求传入的参数并处理参数
通过初始化好的handlerMapping中拿出url对应的方法名,反射调用

三、实现自己的SpringMVC框架

工程文件及目录
image.png

1、首先,新建一个maven项目,在pom.xml中导入以下依赖:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
		 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
		 xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
	<groupId>com.qw.boot.springmvc</groupId>
	<artifactId>springmvc</artifactId>
	<version>1.0.0-SNAPSHOT</version>
	<packaging>war</packaging>

	<modelVersion>4.0.0</modelVersion>

	<properties>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<maven.compiler.source>1.8</maven.compiler.source>
		<maven.compiler.target>1.8</maven.compiler.target>
		<java.version>1.8</java.version>
	</properties>

	<dependencies>
		<dependency>
			<groupId>javax.servlet</groupId>
			<artifactId>javax.servlet-api</artifactId>
			<version>3.0.1</version>
			<scope>provided</scope>
		</dependency>
		<!--Lombok-->
		<dependency>
			<groupId>org.projectlombok</groupId>
			<artifactId>lombok</artifactId>
			<scope>provided</scope>
		</dependency>
	</dependencies>
</project>

2、接着在WEB-INF下创建一个web.xml,如下配置:

<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xmlns="http://java.sun.com/xml/ns/javaee" xmlns:web="http://java.sun.com/xml/ns/javaee/web-app_2_5.xsd"
	xsi:schemaLocation="http://java.sun.com/xml/ns/javaee http://java.sun.com/xml/ns/javaee/web-app_3_0.xsd"
	version="3.0">
	<servlet>
		<servlet-name>springmvc</servlet-name>
		<servlet-class>com.qw.boot.springmvc.servlet.XXDispatcherServlet</servlet-class>
		<init-param>
			<param-name>contextConfigLocation</param-name>
			<param-value>config.properties</param-value>
		</init-param>
		<load-on-startup>1</load-on-startup>
	</servlet>
	<servlet-mapping>
		<servlet-name>springmvc</servlet-name>
		<url-pattern>/*</url-pattern>
	</servlet-mapping>

</web-app>

3、application.properties文件中只是配置要扫描的包到SpringMVC容器中。

scanPackage=com.qw.boot.springmvc.web

4、创建元注解

XXAutowired

@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXAutowired {
    String value() default "";
}

XXController

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXController {
    String value() default "";
}

XXQualifier

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXQualifier {
    String value() default "";
}

XXRequestMapping

@Target({ElementType.TYPE,ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXRequestMapping {
	/**
     * 表示访问该方法的url
     * @return
     */
    String value() default "";

}

XXRequestParam

@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXRequestParam {
	/**
     * 表示参数的别名,必填
     * @return
     */
    String value();

}

XXService

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface XXService {
    String value() default "";
}

5、逻辑代码

UserController.java

@XXController
@XXRequestMapping("/user")
public class UserController {

    @XXAutowired
    private UserService userService;

    @XXRequestMapping("/index")
    public String index(HttpServletRequest request, HttpServletResponse response,
                        @XXRequestParam("name")String name) throws IOException {
        String res = userService.get(name);
        System.out.println(name+"=>"+res);
        response.setContentType("application/json;charset=UTF-8");
        response.getWriter().write(res);
        return "index";
    }

    @XXRequestMapping("/list")
    public String list(HttpServletRequest request,HttpServletResponse response)
            throws IOException{
        List<User> users = userService.list();
        response.setContentType("application/json;charset=UTF-8");
        response.getWriter().write(users.toString());
        return "list";
    }

}

UserService.java

public interface UserService {
    String get(String name);
    List<User> list();
}

UserServiceImpl.java

@XXService("userService")
public class UserServiceImpl implements UserService {

    private static Map<String, User> users = new HashMap<String, User>();

    static{
        users.put("aa", new User("1","aaa","123456"));
        users.put("bb", new User("2","bbb","123456"));
        users.put("cc", new User("3","ccc","123456"));
        users.put("dd", new User("4","ddd","123456"));
        users.put("ee", new User("5","eee","123456"));
    }

    @Override
    public String get(String name) {
        User user = users.get(name);
        if(user==null){
            user = users.get("aa");
        }
        return user.toString();
    }

    @Override
    public List<User> list() {
        List<User> list = new ArrayList<User>();
        for(Map.Entry<String, User> entry : users.entrySet()){
            list.add(entry.getValue());
        }
        return list;
    }

}

6、核心代码

XXDispatcherServlet.java

public class XXDispatcherServlet extends HttpServlet {
    private Properties contextConfig = new Properties();
    private List<String> classNames = new ArrayList<String>();
    private Map<String, Object> ioc = new HashMap<String, Object>();
    private List<Handler> handlerMapping = new ArrayList<Handler>();

    private static final long serialVersionUID = -4943120355864715254L;


    @Override
    public void init(ServletConfig config) throws ServletException {
        //load config
        doLoadConfig(config.getInitParameter("contextConfigLocation"));
        //scan relative class
        doScanner(contextConfig.getProperty("scanPackage"));
        //init ioc container put relative class to it
        doInstance();
        //inject dependence
        doAutoWired();
        //init handlerMapping
        initHandlerMapping();
    }

    private void initHandlerMapping() {
        if (ioc.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(XXController.class)) {
                continue;
            }
            String baseUrl = "";
            if (clazz.isAnnotationPresent(XXRequestMapping.class)) {
                XXRequestMapping requestMapping = clazz.getAnnotation(XXRequestMapping.class);
                baseUrl = requestMapping.value();
            }
            Method[] methods = clazz.getMethods();
            for (Method method : methods) {
                if (!method.isAnnotationPresent(XXRequestMapping.class)) {
                    continue;
                }
                XXRequestMapping requestMapping = method.getAnnotation(XXRequestMapping.class);
                String url = (baseUrl + requestMapping.value()).replaceAll("/+", "/");
                Pattern pattern = Pattern.compile(url);
                handlerMapping.add(new Handler(pattern, entry.getValue(), method));
                System.out.println("mapped:" + url + "=>" + method);
            }
        }
    }

    private void doAutoWired() {
        if (ioc.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            //依赖注入->给加了XXAutowired注解的字段赋值
            Field[] fields = entry.getValue().getClass().getDeclaredFields();
            for (Field field : fields) {
                if (!field.isAnnotationPresent(XXAutowired.class)) {
                    continue;
                }
                XXAutowired autowired = field.getAnnotation(XXAutowired.class);
                String beanName = autowired.value();
                if ("".equals(beanName)) {
                    beanName = field.getType().getName();
                }
                field.setAccessible(true);
                try {
                    field.set(entry.getValue(), ioc.get(beanName));
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                    continue;
                }
            }
        }
    }

    private void doInstance() {
        if (classNames.isEmpty()) {
            return;
        }
        try {
            for (String className : classNames) {
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(XXController.class)) {
                    String beanName = lowerFirstCase(clazz.getSimpleName());
                    ioc.put(beanName, clazz.newInstance());
                } else if (clazz.isAnnotationPresent(XXService.class)) {

                    XXService service = clazz.getAnnotation(XXService.class);
                    String beanName = service.value();
                    if ("".equals(beanName)) {
                        beanName = lowerFirstCase(clazz.getSimpleName());
                    }
                    Object instance = clazz.newInstance();
                    ioc.put(beanName, instance);
                    Class<?>[] interfaces = clazz.getInterfaces();
                    for (Class<?> i : interfaces) {
                        ioc.put(i.getName(), instance);
                    }
                } else {
                    continue;
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void doScanner(String packageName) {
        URL resource =
                this.getClass().getClassLoader().getResource("/" + packageName.replaceAll("\\.", "/"));
        File classDir = new File(resource.getFile());
        for (File classFile : classDir.listFiles()) {
            if (classFile.isDirectory()) {
                doScanner(packageName + "." + classFile.getName());
            } else {
                String className = (packageName + "." + classFile.getName()).replace(".class", "");
                classNames.add(className);
            }
        }
    }

    private void doLoadConfig(String location) {
        InputStream input = this.getClass().getClassLoader().getResourceAsStream(location);
        try {
            contextConfig.load(input);
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (input != null) {
                try {
                    input.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse res)
            throws ServletException, IOException {
        this.doPost(req, res);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse res)
            throws ServletException, IOException {
        doDispatcher(req, res);
    }

    public void doDispatcher(HttpServletRequest req, HttpServletResponse res) {
        try {
            Handler handler = getHandler(req);
            if (handler == null) {
                res.getWriter().write("404 not found.");
                return;
            }
            Class<?>[] paramTypes = handler.method.getParameterTypes();
            Object[] paramValues = new Object[paramTypes.length];
            Map<String, String[]> params = req.getParameterMap();
            for (Entry<String, String[]> param : params.entrySet()) {
                String value = Arrays.toString(param.getValue()).replaceAll("\\[|\\]", "");
                if (!handler.paramIndexMapping.containsKey(param.getKey())) {
                    continue;
                }
                int index = handler.paramIndexMapping.get(param.getKey());
                paramValues[index] = convert(paramTypes[index], value);
            }
            int reqIndex = handler.paramIndexMapping.get(HttpServletRequest.class.getName());
            paramValues[reqIndex] = req;
            int resIndex = handler.paramIndexMapping.get(HttpServletResponse.class.getName());
            paramValues[resIndex] = res;
            handler.method.invoke(handler.controller, paramValues);
        } catch (Exception e) {
            e.printStackTrace();
        }
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "").replaceAll("/+", "/");

    }

    private Object convert(Class<?> type, String value) {
        if (Integer.class == type) {
            return Integer.valueOf(value);
        }
        return value;
    }

    private String lowerFirstCase(String str) {
        char[] chars = str.toCharArray();
        chars[0] += 32;
        return String.valueOf(chars);
    }

    private Handler getHandler(HttpServletRequest req) {
        if (handlerMapping.isEmpty()) {
            return null;
        }
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "").replaceAll("/+", "/");
        for (Handler handler : handlerMapping) {
            Matcher matcher = handler.pattern.matcher(url);
            if (!matcher.matches()) {
                continue;
            }
            return handler;
        }
        return null;
    }

    private class Handler {
        protected Object controller;
        protected Method method;
        protected Pattern pattern;
        protected Map<String, Integer> paramIndexMapping;

        protected Handler(Pattern pattern, Object controller, Method method) {
            this.pattern = pattern;
            this.controller = controller;
            this.method = method;
            paramIndexMapping = new HashMap<String, Integer>();
            putParamIndexMapping(method);
        }

        private void putParamIndexMapping(Method method) {
            Annotation[][] pa = method.getParameterAnnotations();
            for (int i = 0; i < pa.length; i++) {
                for (Annotation a : pa[i]) {
                    if (a instanceof XXRequestParam) {
                        String paramName = ((XXRequestParam) a).value();
                        if (!"".equals(paramName)) {
                            paramIndexMapping.put(paramName, i);
                        }
                    }
                }
            }
            Class<?>[] paramTypes = method.getParameterTypes();
            for (int i = 0; i < paramTypes.length; i++) {
                Class<?> type = paramTypes[i];
                if (type == HttpServletRequest.class || type == HttpServletResponse.class) {
                    paramIndexMapping.put(type.getName(), i);
                }
            }
        }
    }
}

测试:

1、http://localhost:8080/springmvc/user/index?name=bb
image.png

2、http://localhost:8080/springmvc/user/list
image.png
3、http://localhost:8080/springmvc/user/detail 出现404
image.png