欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

手写spring mvc底层源码

程序员文章站 2022-04-07 21:19:05
...

手写源码之前先来了解几个概念:

1、spring IOC:控制反转,简单来说,就是tomcat在运行得时候创建了一个map,将那些有注解的对象存入这个map中,然后通过注解来获取这些对象供程序使用

2、DI:依赖注入,动态的向某个对象提供它需要的对象

3、DispatcherServlet:Spring MVC底层的具体实现,一般我们会选择默认提供的org.springframework.web.servlet.DispatcherServlet

 

如果想自己实现底层源码,那么就得自己手写DisPatcherServlet

1、首先自己定义几个注解,例如:

@Target(java.lang.annotation.ElementType.FIELD)//只能在类属性上使用

@Retention(RetentionPolicy.RUNTIME)//运行时调用

@Documented

public @interface TestAutowired {

String value() default "";

}

 

@Target(java.lang.annotation.ElementType.TYPE)//只能在类上使用

@Retention(RetentionPolicy.RUNTIME)//运行时调用

@Documented

public @interface TestController {

String value() default "";

}

 

@Target({java.lang.annotation.ElementType.TYPE,java.lang.annotation.ElementType.METHOD})//只能在类、方法上使用

@Retention(RetentionPolicy.RUNTIME)//运行时调用

@Documented

public @interface TestRequestMapping {

String value() default "";

}

 

@Target(java.lang.annotation.ElementType.PARAMETER)//只能在方法上使用

@Retention(RetentionPolicy.RUNTIME)//运行时调用

@Documented

public @interface TestRequestParam {

String value() default "";

}

 

@Target(java.lang.annotation.ElementType.TYPE)//只能在类上使用

@Retention(RetentionPolicy.RUNTIME)//运行时调用

@Documented

public @interface TestService {

String value() default "";

}

 

2、自己手写DispatcherServlet

public class DispatcherServlet extends HttpServlet{

 

List<String> classNames = new ArrayList<String>(); 

Map<String, Object> beans = new HashMap<String, Object>();

Map<String, Object> handlerMap = new HashMap<String, Object>();

 

@Override

public void init(ServletConfig servletConfig) throws ServletException {

//1、扫描

scanPackage("com.wzy");

//2、实例化

doInstance();

//3、自动注入

doAutowired();

//4、方法映射,路径与方法对应

doUrlMapping();

}

 

public void doUrlMapping() {

for(Map.Entry<String, Object> entry : beans.entrySet()) {

Object instance = entry.getValue();

Class<?> clazz = instance.getClass();

if(clazz.isAnnotationPresent(TestController.class)) {

//获取类的上的请求路径

TestRequestMapping classMapping = clazz.getAnnotation(TestRequestMapping.class);

String classPath = classMapping.value();

Method[] methods = clazz.getMethods();

for(Method method : methods) {

if(method.isAnnotationPresent(TestRequestMapping.class)) {

TestRequestMapping methodMapping= method.getAnnotation(TestRequestMapping.class);

String methodPath = methodMapping.value();

handlerMap.put(classPath+methodPath, method);

}

}

}

}

}

 

public void doAutowired() {

for(Map.Entry<String, Object> entry : beans.entrySet()) {

Object instance = entry.getValue();

Class<?> clazz = instance.getClass();

if(clazz.isAnnotationPresent(TestController.class)) {

Field[] fields = clazz.getDeclaredFields();

for(Field field : fields) {

if(field.isAnnotationPresent(TestAutowired.class)) {

TestAutowired testAutowired = field.getAnnotation(TestAutowired.class);

String key = testAutowired.value();

Object obj = beans.get(key);

if(obj != null) {

//打开权限

field.setAccessible(true);

try {

field.set(instance, obj);

} catch (IllegalArgumentException e) {

e.printStackTrace();

} catch (IllegalAccessException e) {

e.printStackTrace();

}

}else {

throw new RuntimeException(obj.getClass().getName()+"was not injection!");

}

}else {

continue;

}

}

}else {

continue;

}

}

}

 

public void doInstance() {

for(String className : classNames) {

String cn = className.replace(".class", "");

try {

Class<?> clazz = Class.forName(cn);

if(clazz.isAnnotationPresent(TestController.class)) {

//创建实例

Object instance = clazz.newInstance();

TestRequestMapping m1 = clazz.getAnnotation(TestRequestMapping.class);

beans.put(m1.value(), instance);

}else if(clazz.isAnnotationPresent(TestService.class)) {

//创建实例

Object instance = clazz.newInstance();

TestService m2 = clazz.getAnnotation(TestService.class);

beans.put(m2.value(), instance);

}else {

continue;

}

} catch (ClassNotFoundException e) {

e.printStackTrace();

} catch (InstantiationException e) {

e.printStackTrace();

} catch (IllegalAccessException e) {

e.printStackTrace();

}

}

}

 

public void scanPackage(String basePackage) {

//查找类的路径

URL url = this.getClass().getClassLoader().getResource(

"/"+basePackage.replaceAll("\\.", "/"));

String fileStr = url.getFile();

File file = new File(fileStr);

String[] filesStr = file.list();

for(String path : filesStr) {

File filePath = new File(fileStr+path);

if(filePath.isDirectory()) {

scanPackage(basePackage+"."+path);

}else {

classNames.add(basePackage+"."+filePath.getName());

}

}

}

 

 

@Override

protected void doGet(HttpServletRequest req, HttpServletResponse resp) 

throws ServletException, IOException {

this.doPost(req, resp);

}

 

@Override

protected void doPost(HttpServletRequest req, HttpServletResponse resp) 

throws ServletException, IOException {

String uri = req.getRequestURI();

String context = req.getContextPath();

String path = uri.replace(context, "");

Method method = (Method) handlerMap.get(path);

TestController instance = (TestController) beans.get("/"+path.split("/")[1]);

Object[] args = hand(req, resp, method);

try {

method.invoke(instance, args);

} catch (IllegalAccessException e) {

e.printStackTrace();

} catch (IllegalArgumentException e) {

e.printStackTrace();

} catch (InvocationTargetException e) {

e.printStackTrace();

}

}

 

private static Object[] hand(HttpServletRequest request, HttpServletResponse response, Method method) {

//拿到当前待执行的方法有哪些参数

Class<?>[] paramClazzs = method.getParameterTypes();

//根据参数的个数,new 一个参数的数组,将方法里的所有参数赋值到args来

Object[] args = new Object[paramClazzs.length];

int arg_i = 0;

int index = 0;

for(Class<?> paramClazz : paramClazzs) {

if(ServletRequest.class.isAssignableFrom(paramClazz)) {

args[arg_i++] = request;

}

if(ServletResponse.class.isAssignableFrom(paramClazz)) {

args[arg_i++] = response;

}

//判断有没有RequestParam注释

Annotation[] paramAns = method.getParameterAnnotations()[index];

if(paramAns.length > 0) {

for(Annotation annotation : paramAns) {

if(TestRequestParam.class.isAssignableFrom(paramAns.getClass())) {

TestRequestParam rp = (TestRequestParam) annotation;

args[arg_i++] = request.getParameter(rp.value());

}

}

}

index ++;

}

return args;

}

 

}

具体讲解下:

scanPackage方法,主要是遍历当前项目,将当前项目中的所有类都存于map中

doInstance方法,主要是遍历上述所有类,将类上有@TestController和@TestService的类进行实例化,然后存于另外的map中,key取@TestRequestMapping后的value值

doAutowired方法,主要是将所有带有@TestController类中的属性上带有@TestAutowired的这些属性从之前的map中获取实例,放进当前的属性中

doUrlMapping方法,将所有@TestController中方法上带有@TestRequestMapping的跟注解后的value路径进行对应

hand,获取当前请求对应的方法中有哪些参数

 

注意:本文没有提到其实在@TestService注解的类下其实也可有@TestAutoWired注解的属性,这个其实可以根据上面的逻辑进行扩展;还有当某个注解后没有value属性值的话,key值该怎么取,可以先获取当前类、属性的名称,然后将首字母小写,将其作为key就可以了

 

附上本人练习项目控制层和服务层:

@TestController

@TestRequestMapping("/test")

public class OrderController {

 

@TestAutowired("orderServiceImpl")

private OrderService orderService;

 

@TestRequestMapping("/get")

public String get(@TestRequestParam("name")String name, @TestRequestParam("age")int age) {

return orderService.query(name, age);

}

 

}

 

@TestService("orderServiceImpl")

public class OrderServiceImpl implements OrderService{

 

public String query(String name, int age) {

return "[name:"+name+",age"+age+"]";

}

 

}

 转发表明来源

 

 

 

相关标签: spring mvc spring