JAVA学习笔记--手写简单的Mybatis

it2026-01-07  10

1. 传统的JDBC模式

在没有使用ORM框架时,我们基本都是通过JDBC进行数据库的操作,一般的逻辑代码如下:

public static void main(String[] args) { Connection connection = null; PreparedStatement preparedStatement = null; ResultSet resultSet = null; try { // 加载数据库驱动 Class.forName("com.mysql.jdbc.Driver"); // 通过驱动管理类获取数据库链接 connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/mybatis? characterEncoding=utf-8", "root", "root"); // 定义sql语句?表示占位符 String sql = "select * from user where username = ?"; // 获取预处理statement preparedStatement = connection.prepareStatement(sql); // 设置参数,第一个参数为sql语句中参数的序号(从1开始),第二个参数为设置的参数值 preparedStatement.setString(1, "tom"); // 向数据库发出sql执行查询,查询出结果集 resultSet = preparedStatement.executeQuery(); // 遍历查询结果集 while (resultSet.next()) { int id = resultSet.getInt("id"); String username = resultSet.getString("username"); // 封装User user.setId(id); user.setUsername(username); } System.out.println(user); } } catch (Exception e) { e.printStackTrace(); } finally { // 释放资源 if (resultSet != null) { try { resultSet.close(); } catch (SQLException e) { e.printStackTrace(); } } if (preparedStatement != null) { try { preparedStatement.close(); } catch (SQLException e) { e.printStackTrace(); } } if (connection != null) { try { connection.close(); } catch (SQLException e) { e.printStackTrace(); } } }

通过对JDBC的代码进行分析,可以发现存在以下问题: 原始jdbc开发存在的问题如下:

数据库连接创建、释放频繁造成系统资源浪费,从而影响系统性能;Sql语句在代码中硬编码,造成代码不易维护,实际应用中sql变化的可能较大,sql变动需要改变java代码。使用preparedStatement向占位符传参存在硬编码,因为sql语句的where条件不一定,可能多也可能少,修改sql还要修改代码,系统不易维护。对结果集解析存在硬编码(查询列名),sql变化导致解析代码变化,系统不易维护,如果能将数据库记录封装成pojo对象解析比较方便。

针对上述的问题,由于硬编码可以通过配置文件解决,所以我们可以想到如下的解决思路:

使用数据库连接池初始化连接资源;将sql语句抽取到xml配置文件中;使用反射、内省等底层技术,自动将实体与表进行属性与字段的自动映射;

根据上述的思路进行自定义的Mybatis框架编写。

2. 客户端开发

首先客户端要提供数据库的连接信息以及SQL的配置信息,根据设计思路,通过配置文件来解决,所以在resources目录下创建sqlMapConfig.xml文件,内容如下:

<configuration> <!-- 数据库配置信息--> <dataSource> <property name="driver" value="com.mysql.jdbc.Driver"></property> <property name="jdbcUrl" value="jdbc:mysql:///test"></property> <property name="username" value="root"></property> <property name="password" value="123456"></property> </dataSource> <mappers> <mapper resource="UserMapper.xml"></mapper> </mappers> </configuration>

在这个文件中配置了要加载那些mapper.xml,这样只需要加载一次就完成了所有的加载。 在同个目录下创建映射配置文件UserMapper.xml,并定义了简单的增删改查的SQL。内容如下:

<mapper namespace="com.ormtest.mapper.UserMapper"> <select id="selectAll" resultType="com.ormtest.pojo.User"> select * from user </select> <select id="selectOne" resultType="com.ormtest.pojo.User" parameterType="com.ormtest.pojo.User"> select * from user where id = #{id} and username = #{username} </select> <!--添加用户--> <insert id="insertUser" parameterType="com.ormtest.pojo.User" > insert into user values(#{id},#{username},#{password}) </insert> <!--修改--> <update id="updateUser" parameterType="com.ormtest.pojo.User"> update user set username = #{username} where id = #{id} </update> <!--删除--> <delete id="deleteUser" parameterType="java.lang.Integer"> delete from user where id = #{id} </delete> </mapper>

同时也要生成对应的POJO对象,并定义Dao层接口,代码简单,此处就不粘贴了。 到此为止,客户的代码就编写完成了,等完成框架的编写之后,就可以进行测试了。

3. 自定义框架开发

开发的时候,我们根据逻辑流程进行。

首先我们需要一个类来接受和加载核心配置文件:

public class Resources { // 将xml配置文件加载成为字节流 public static InputStream getResourceAsSteam(String path) { InputStream resourceAsStream = Resources.class.getClassLoader().getResourceAsStream(path); return resourceAsStream; } }

核心配置文件加载完成之后,就需要对字节流进行解析,所以声明一个解析xml的类,最终这个类要返回一个SqlSessionFactory的对象:

public class SqlSessionFactoryBuilder { // 使用构造器模式,将复杂对象进行逐步构建 public SqlSessionFactory build(InputStream in) throws Exception { // 使用dom4j读取字节流(也就是核心配置文件)的内容 // 封装成一个configuration对象 XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder(); Configuration configuration = xmlConfigBuilder.parseConfig(in); // 根据配置信息创建sqlSession的工厂 SqlSessionFactory sqlSessionFactory = new DefaultSqlSessionFactory(configuration); return sqlSessionFactory; } }

在这个解析的过程中,声明了一些解析类,比如XMLConfigBuilder进行核心配置文件的解析,里面还嵌套调用了mapper.xml的解析类,代码如下:

public class XMLConfigBuilder { private Configuration configuration; private static final String DRIVER = "driver"; private static final String JDBCURL = "jdbcUrl"; private static final String USERNAME = "username"; private static final String PASSWORD = "password"; public XMLConfigBuilder() { this.configuration = new Configuration(); } /** * 通过dom4j进行字节流解析 * @param inputStream * @return */ public Configuration parseConfig(InputStream inputStream) throws DocumentException, PropertyVetoException { // 借助dom4j,进行解析,得到整个的文档对象 Document document = new SAXReader().read(inputStream); // 得到根对象,即configuration标签 Element rootElement = document.getRootElement(); // 获取property标签,进行数据库连接信息的解析加载 List<Element> elementList = rootElement.selectNodes("//property"); // 借助properties对象,进行属性保存 Properties properties = new Properties(); for (Element element : elementList) { String name = element.attributeValue("name"); String value = element.attributeValue("value"); properties.setProperty(name, value); } // 进行数据库连接信息封装,使用C3P0连接池 ComboPooledDataSource comboPooledDataSource = new ComboPooledDataSource(); comboPooledDataSource.setDriverClass(properties.getProperty(DRIVER)); comboPooledDataSource.setJdbcUrl(properties.getProperty(JDBCURL)); comboPooledDataSource.setUser(properties.getProperty(USERNAME)); comboPooledDataSource.setPassword(properties.getProperty(PASSWORD)); // 将数据源进行保存 configuration.setDataSource(comboPooledDataSource); // 进行mapper.xml的解析工作 // 首先得到需要加载的xml List<Element> mappers = rootElement.selectNodes("//mapper"); for (Element element : mappers) { // 得到需要加载的mapper文件路径 String mapperPath = element.attributeValue("resource"); // 对mapper文件进行解析,得到mapperstatement InputStream resourceAsSteam = Resources.getResourceAsSteam(mapperPath); XMLMapperBuilder xmlMapperBuilder = new XMLMapperBuilder(configuration); xmlMapperBuilder.parseMapper(resourceAsSteam); } // 返回整体配置对象 return this.configuration; } }

通过解析mapper.xml,生成对应的mappedStatement对象,存入Configuration对象中,因为全局只有一个Configuration对象,所以每次都把这个对象进行传递。

public class XMLMapperBuilder { private Configuration configuration; public XMLMapperBuilder(Configuration configuration) { this.configuration = configuration; } /** * 使用dom4j解析mapper.xml * @param inputStream */ public void parseMapper(InputStream inputStream) throws DocumentException { Document document = new SAXReader().read(inputStream); Element rootElement = document.getRootElement(); // 得到当前mapper的namespace String namespace = rootElement.attributeValue("namespace"); //递归遍历当前节点所有的子节点 List<Element> elementList = rootElement.elements(); for (Element element : elementList) { String id = element.attributeValue("id"); String resultType = element.attributeValue("resultType"); String parameterType = element.attributeValue("parameterType"); String sqlText = element.getTextTrim(); MappedStatement mappedStatement = new MappedStatement(); mappedStatement.setId(id); mappedStatement.setResultType(resultType); mappedStatement.setParameterType(parameterType); mappedStatement.setSql(sqlText); // 增加sql类型字段 mappedStatement.setSqlCommandType(element.getName()); // mapper的namespace和SQL语句的id,组成唯一id String key = namespace+"."+id; configuration.getMappedStatementMap().put(key,mappedStatement); } } }

到此,所有的配置文件都加载完成,最终根据生成的Configuration对象,返回对应的SqlSessionFactory对象:

public class DefaultSqlSessionFactory implements SqlSessionFactory{ // 通过构造方法注入,保证从上到下只有一个configuration对象 private Configuration configuration; public DefaultSqlSessionFactory(Configuration configuration) { this.configuration = configuration; } @Override public SqlSession openSession() { return new DefaultSqlSession(configuration); } }

当客户端调用openSession方法时,就会返回一个SqlSession对象,该对象里面封装了增删改查的方法,供客户端调用:

public class DefaultSqlSession implements SqlSession{ private Configuration configuration; public DefaultSqlSession(Configuration configuration) { this.configuration = configuration; } @Override public Object doQuery(String statementId, boolean resultTypeFlag, Object... params) throws Exception { // 根据statementId得到要执行的SQL对象 MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId); // 判断当前的SQL类型是什么 switch (mappedStatement.getSqlCommandType()) { case "select": { if (resultTypeFlag) { return selectList(statementId, params); } else { return selectOne(statementId, params); } } case "insert":{ return insert(statementId, params); } case "update":{ return update(statementId, params); } case "delete":{ return delete(statementId, params); } } return null; } @Override public <E> List<E> selectList(String statementId, Object... params) throws Exception { // 根据statementId得到要执行的SQL对象 MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId); Executor executor = new SimpleExecutor(); List<Object> result = executor.query(configuration, mappedStatement, params); return (List<E>) result; } @Override public <T> T selectOne(String statementId, Object... params) throws Exception { List<Object> objects = this.selectList(statementId, params); if (1 == objects.size()) { return (T) objects.get(0); } else { throw new RuntimeException("查询结果为空或查询结果过多"); } } @Override public int insert(String statementId, Object... params) throws Exception { MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId); Executor executor = new SimpleExecutor(); executor.updateDatabase(configuration, mappedStatement, params); return 1; } @Override public int update(String statementId, Object... params) throws Exception { MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId); Executor executor = new SimpleExecutor(); executor.updateDatabase(configuration, mappedStatement, params); return 1; } @Override public int delete(String statementId, Object... params) throws Exception { MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId); Executor executor = new SimpleExecutor(); executor.updateDatabase(configuration, mappedStatement, params); return 1; } @Override public <T> T getMapper(Class<?> mapperClass) { // 根据JDK动态代理生成代理对象,对方法进行加工 Object proxyInstance = Proxy.newProxyInstance(DefaultSqlSession.class.getClassLoader(), new Class[]{mapperClass}, new InvocationHandler() { @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { // 这里无法获取mapper.xml的信息,所以为了方便识别,需要将namespace和SQL语句的id与接口的全限定名和方法保持一致 // 获取方法名 String methodName = method.getName(); // 获取所属的接口class名称 String className = method.getDeclaringClass().getName(); // 得到唯一的statementId String statementId = className + "." + methodName; // 根据方法的返回结果类型进行判断 Type genericReturnType = method.getGenericReturnType(); if(genericReturnType instanceof ParameterizedType){ return doQuery(statementId, true, args); } return doQuery(statementId, false, args); } }); return (T) proxyInstance; } }

最终的增删改查操作实际上还是通过JDBC来实现,所以声明一个Executor类来专门执行具体操作:

public class SimpleExecutor implements Executor{ @Override public <E> List<E> query(Configuration configuration, MappedStatement mappedStatement, Object... params) throws Exception { PreparedStatement preparedStatement = getPreparedStatement(configuration, mappedStatement, params); // 6 执行SQL,得到结果集 ResultSet resultSet = preparedStatement.executeQuery(); // 7 对结果集进行转换 List<Object> resultList = new ArrayList<>(); // 7.1 得到返回结果的类型 String resultType = mappedStatement.getResultType(); // 7.2 转换为类 Class<?> resultTypeClass = getClassType(resultType); // 7.3 遍历结果集,逐个进行转换 while (resultSet.next()) { // 声明返回类 Object o = resultTypeClass.newInstance(); // 获取元数据 ResultSetMetaData metaData = resultSet.getMetaData(); // 此处是从1开始 for (int i = 1; i <= metaData.getColumnCount(); i++) { // 字段名 String columnName = metaData.getColumnName(i); // 对应字段的值 Object value = resultSet.getObject(columnName); //使用内省,根据数据库表字段和实体属性的对应关系,完成封装 PropertyDescriptor propertyDescriptor = new PropertyDescriptor(columnName, resultTypeClass); // 获取写方法,进行值写入 Method writeMethod = propertyDescriptor.getWriteMethod(); writeMethod.invoke(o,value); } // 将本次转换的结果加入返回结果中 resultList.add(o); } return (List<E>) resultList; } @Override public boolean updateDatabase(Configuration configuration, MappedStatement mappedStatement, Object... params) throws ClassNotFoundException, SQLException, IllegalAccessException, NoSuchFieldException { PreparedStatement preparedStatement = getPreparedStatement(configuration, mappedStatement, params[0]); boolean execute = preparedStatement.execute(); return execute; } private PreparedStatement getPreparedStatement(Configuration configuration, MappedStatement mappedStatement, Object... params) throws SQLException, ClassNotFoundException, NoSuchFieldException, IllegalAccessException { // 执行jdbc过程 // 1 注册驱动获取连接,直接获取C3P0连接池里的连接 Connection connection = configuration.getDataSource().getConnection(); // 2 获取要执行的SQL String sql = mappedStatement.getSql(); // 3 对SQL中存在的参数进行提取和转换 BoundSql boundSql = getBoundSql(sql); // 4 获取预处理对象 PreparedStatement preparedStatement = connection.prepareStatement(boundSql.getSqlText()); // 5 设置参数 // 5.1 获取参数类型 String parameterType = mappedStatement.getParameterType(); // 5.2 根据类型获取类 Class<?> parameterTypeClass = getClassType(parameterType); // 5.3 得到解析的参数列表 List<ParameterMapping> parameterMappingList = boundSql.getParameterMappingList(); for (int i = 0; i < parameterMappingList.size(); i++) { ParameterMapping parameterMapping = parameterMappingList.get(i); // 得到参数值,即#{id}中的id String content = parameterMapping.getContent(); Field declaredField = null; // 判断是否是基本数据类型或者其包装类 if (isCommonDataType(parameterTypeClass) || isWrapClass(parameterTypeClass)) { declaredField = parameterTypeClass.getDeclaredField("value"); } else { declaredField = parameterTypeClass.getDeclaredField(content); } // 设置权限暴力访问,防止属性私有不让访问 declaredField.setAccessible(true); // 得到对应的值 Object o = declaredField.get(params[0]); // 将参数拼接到SQL上 preparedStatement.setObject(i + 1, o); } return preparedStatement; } /** * 判断当前类型是否是基本数据类型 * @param clazz * @return */ private Boolean isCommonDataType(Class clazz){ return clazz.isPrimitive(); } private boolean isWrapClass(Class clazz){ try { return ((Class) clazz.getField("TYPE").get(null)).isPrimitive(); } catch (Exception e) { return false; } } /** * 根据类型获取对应的class * @param type * @return */ private Class<?> getClassType(String type) throws ClassNotFoundException { if(type != null){ Class<?> clazz = Class.forName(type); return clazz; } return null; } /** * 对mapper中的原SQL进行解析和替换 * 由于jdbc只认识?占位符,所以要把#{id}进行替换 * 同时要得到其中的id,用于定位获取参数 * @param sql * @return */ private BoundSql getBoundSql(String sql) { // 标记处理类 ParameterMappingTokenHandler parameterMappingTokenHandler = new ParameterMappingTokenHandler(); // 第一个参数是开始标记、第二个是结束标记、第三个是使用那个处理类 GenericTokenParser genericTokenParser = new GenericTokenParser("#{", "}", parameterMappingTokenHandler); // 得到处理后的SQL(参数已经变为?) String formatSql = genericTokenParser.parse(sql); // 处理过程中,处理类已经将参数中的值进行了存储,直接获取即可 List<ParameterMapping> parameterMappings = parameterMappingTokenHandler.getParameterMappings(); // 通过构造方法进行赋值 BoundSql boundSql = new BoundSql(formatSql, parameterMappings); return boundSql; } }

还有一些基础的POJO这里就不再粘贴了,到此,框架的代码就基本开发完毕,下面进行测试。

4. 测试

在客户端里声明测试类,进行如下的代码编写:

private SqlSession sqlSession; // 在测试方法执行前执行 @Before public void prepare() throws Exception { // 将配置文件进行加载,得到字节流 InputStream resourceAsSteam = Resources.getResourceAsSteam("sqlMapConfig.xml"); // 此次加载会将xml文件里面的内容加载到框架内 // 通过字节流,得到sqlSession工厂 SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(resourceAsSteam); // 通过工厂的openSession方法,生成一个session sqlSession = sqlSessionFactory.openSession(); } // 不使用代理的模式 @Test public void test() throws Exception { List<User> userList = sqlSession.selectList("com.ormtest.mapper.UserMapper.selectAll"); for (User user : userList) { System.out.println(user); } } // 使用代理模式 @Test public void test2() throws Exception { UserMapper userMapper = sqlSession.getMapper(UserMapper.class); List<User> userList = userMapper.selectAll(); for (User user : userList) { System.out.println(user); } }

最终测试通过,查询结果正确。

至此,自定义Mybatis的简单实现就完成了~~

最新回复(0)