客户端安全传输密码至服务端的实现改进
程序员文章站
2022-04-23 10:29:25
两年前在做Java EE开发平台时,因为用户登录相关的模块是委托给另一位同事完成的,所以虽然知道大体概念,但是对客户端怎么安全传输密码到服务端的具体细节并不甚了解。然而这次在做4A系统(认证、授权、监控、审计)时,无论怎样都绕不过这一块内容了,于是在仔细研究了一下之前的方案,并参考网上的一些资料后, ......
两年前在做Java EE开发平台时,因为用户登录相关的模块是委托给另一位同事完成的,所以虽然知道大体概念,但是对客户端怎么安全传输密码到服务端的具体细节并不甚了解。然而这次在做4A系统(认证、授权、监控、审计)时,无论怎样都绕不过这一块内容了,于是在仔细研究了一下之前的方案,并参考网上的一些资料后,做了一些改进,特此记录一下。
总体方案
- 服务端生成RSA密钥对,并将公钥返回给客户端
- 客户端在提交登录时,将密码使用公钥加密,传输给服务端
- 服务端收到登录请求后,使用私钥解密,并进行下一步处理
方案很简单,也很容易理解,只需要知道非对称加密的一般概念即可,但是在具体的实现时还是有一些细节需要注意的。
原有的实现(实际上同事也是参考了网上的资料)
- 创建一个表
BF_KEY_CFG(module,public_empoent,private_empoent VARCHAR(200))
这里需要对RSA的具体实现有一点点了解,知道模、公钥指数、私钥指数等概念。 - 引入加密工具包
bouncycastle-1.0.jar
,在系统每次启动的时候,利用该工具包生成密钥对(模、公钥指数、私钥指数)并保存到数据库 - 在客户端提交登录时,先从服务器获取公钥,然后再加密发送登录请求
- 服务端在进行密码验证前,先用私钥解密,然后再进行其他认证处理
这个实现有几个问题:
- 创建的表包含三个字段(模、公钥指数、私钥指数)并不那么容易理解,不像公钥、私钥那么大众化
- 需要额外引入第三方包
- 多个服务端节点时,启动时会有潜在的冲突
- 使用的JS加密算法,在服务端解密时,出现顺序颠倒的问题,因此也不能很好的处理中文等字符
- 在提交登录时,是重写了原来的密码表单域的值,导致不能记住密码(每次加密的公钥不相同)
- 在服务端,也不能很好的利用Spring MVC的表单校验,比如需要自己单独写逻辑判断密码长度不能大于16等
此外,网上的资料大都是你抄我、我抄你,一些基本的工具类代码都很丑陋,到处充斥着重复的代码。
改进后的实现
改进主要是针对上面提到的问题进行的。
- 首先,表字段只保留
PUBLIC_KEY、PRIVATE_KEY
- 其次,使用JDK原生的加密
provider
- 解决JS和JAVA加解密顺序颠倒的问题,从而也解决了中文加密问题
- 在前端设置一个隐藏域,提交登录表单时,先将隐藏域的值设置为加密后的密码,然后将表单域设置为disable(从而不会传送至后台),最后再提交登录表单,登录失败返回登录页时重新启用密码域,这样可以解决记住密码,回退到登录页等许多问题
- 在Spring MVC绑定参数后,进行参数校验前,进行解密,从而可以利用Spring MVC的原生校验
这几个改进,前端使用JSEncrypt库配合jQuery非常容易实现
后端则主要包括两个部分:加密工具类、Spring MVC参数绑定后和参数校验前的处理逻辑插入
(1)加密工具类:
import java.security.NoSuchAlgorithmException; import java.security.Provider; import java.security.SecureRandom; import javax.crypto.Cipher; import javax.crypto.NoSuchPaddingException; import org.springframework.util.Base64Utils;//在JDK8中可以直接使用JDK原生Base64工具类 /* package */ class CryptoUtils { protected static final SecureRandom random = new SecureRandom(); /** * 获取Provider对象:注意单例 * * 方便使用其它的Provider * @return */ protected static Provider getProvider() { return null;//provider; } /** * 字节转换为字符串 * * @param data * @return */ protected static String encodeToString(byte[] data) { return Base64Utils.encodeToString(data); } /** * 字符串转换为字节 * * @param src * @return */ protected static byte[] decodeFromString(String src) { return Base64Utils.decodeFromString(src); } /** * 获取Cipher对象 * * @param keyAlgorithm * @return */ protected static Cipher getCipher(String keyAlgorithm) { try { Provider provider = getProvider(); Cipher cipher = null; if (null == provider) { cipher = Cipher.getInstance(keyAlgorithm); } else { cipher = Cipher.getInstance(keyAlgorithm, provider); } return cipher; } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); } catch (NoSuchPaddingException e) { throw new RuntimeException(e); } } }
import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.Serializable; import java.security.InvalidKeyException; import java.security.Key; import java.security.KeyFactory; import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.Provider; import java.security.PublicKey; import java.security.Signature; import java.security.SignatureException; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.X509EncodedKeySpec; import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; public class RSAUtils extends CryptoUtils { /** * 加密算法RSA */ private static final String KEY_ALGORITHM = "RSA"; /** * 生成RSA密钥对 */ public static RsaKeyPair getRsaKeyPair() { try { KeyPairGenerator keyPairGen = null; Provider provider = getProvider(); if (null == provider) { keyPairGen = KeyPairGenerator.getInstance(KEY_ALGORITHM); } else { keyPairGen = KeyPairGenerator.getInstance(KEY_ALGORITHM, provider); } keyPairGen.initialize(1024, random); KeyPair keyPair = keyPairGen.generateKeyPair(); String privateKey = encodeToString(keyPair.getPrivate().getEncoded()); String publicKey = encodeToString(keyPair.getPublic().getEncoded()); return new RsaKeyPair(publicKey, privateKey); } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); } } /** * 获取公钥 * * @param publicKey * @return */ public static PublicKey getPublicKey(String publicKey) { try { byte[] bytes = decodeFromString(publicKey); X509EncodedKeySpec x509KeySpec = new X509EncodedKeySpec(bytes); PublicKey pk = getKeyFactory().generatePublic(x509KeySpec); return pk; } catch (InvalidKeySpecException e) { throw new RuntimeException(e); } } /** * 获取私钥 * * @param privateKey * @return */ public static PrivateKey getPrivateKey(String privateKey) { try { byte[] bytes = decodeFromString(privateKey); PKCS8EncodedKeySpec pkcs8KeySpec = new PKCS8EncodedKeySpec(bytes); PrivateKey pk = getKeyFactory().generatePrivate(pkcs8KeySpec); return pk; } catch (InvalidKeySpecException e) { throw new RuntimeException(e); } } /** * 公钥加密 * * @param plainText * @param publicKey * @return */ public static String encryptByPublicKey(String plainText, String publicKey) { Cipher cipher = getCipher(true, publicKey, true); byte[] src = plainText.getBytes(); byte[] datas = cipherData(cipher, src, 117); return encodeToString(datas); } /** * 私钥加密 * * @param plainText * @param privateKey * @return */ public static String encryptByPrivateKey(String plainText, String privateKey) { Cipher cipher = getCipher(true, privateKey, false); byte[] src = plainText.getBytes(); byte[] datas = cipherData(cipher, src, 117); return encodeToString(datas); } /** * 公钥解密 * * @param plainText * @param publicKey * @return */ public static String decryptByPublicKey(String encryptText, String publicKey) { Cipher cipher = getCipher(false, publicKey, true); byte[] src = decodeFromString(encryptText); byte[] datas = cipherData(cipher, src, 128); return new String(datas); } /** * 私钥解密 * * @param encryptText * @param privateKey * @return */ public static String decryptByPrivateKey(String encryptText, String privateKey) { Cipher cipher = getCipher(false, privateKey, false); byte[] src = decodeFromString(encryptText); byte[] datas = cipherData(cipher, src, 128); return new String(datas); } /** * 获取KeyFactory对象 * * @return */ private static KeyFactory getKeyFactory() { try { Provider provider = getProvider(); if (null == provider) { return KeyFactory.getInstance(KEY_ALGORITHM); } else { return KeyFactory.getInstance(KEY_ALGORITHM, provider); } } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); } } /** * 获取Cipher对象 * * @param encode * @param keyString * @param isPublicKey * @return */ private static Cipher getCipher(boolean encode, String keyString, boolean isPublicKey) { try { Key key = isPublicKey ? getPublicKey(keyString) : getPrivateKey(keyString); Cipher cipher = getCipher(KEY_ALGORITHM); cipher.init(encode ? Cipher.ENCRYPT_MODE : Cipher.DECRYPT_MODE, key); return cipher; } catch (InvalidKeyException e) { throw new RuntimeException(e); } } /** * 加解密数据 * * @param cipher * @param src * @param blockSize * @return */ private static byte[] cipherData(Cipher cipher, byte[] src, int blockSize) { ByteArrayOutputStream out = null; try { int inputLen = src.length; out = new ByteArrayOutputStream(); int offSet = 0; byte[] cache; int i = 0; while (inputLen - offSet > 0) { if (inputLen - offSet > blockSize) { cache = cipher.doFinal(src, offSet, blockSize); } else { cache = cipher.doFinal(src, offSet, inputLen - offSet); } out.write(cache, 0, cache.length); i++; offSet = i * blockSize; } return out.toByteArray(); } catch (IllegalBlockSizeException e) { throw new RuntimeException(e); } catch (BadPaddingException e) { throw new RuntimeException(e); } finally { if (null != out) { try { out.close(); } catch (IOException e) { e.printStackTrace(); } } } } /** * RSA密钥对象 */ public static class RsaKeyPair implements Serializable { private static final long serialVersionUID = 2130689702406754025L; /** * 公钥 */ private String publicKey; /** * 私钥 */ private String privateKey; public RsaKeyPair() {} public RsaKeyPair(String publicKey, String privateKey) { this.publicKey = publicKey; this.privateKey = privateKey; } public String getPublicKey() { return publicKey; } public void setPublicKey(String publicKey) { this.publicKey = publicKey; } public String getPrivateKey() { return privateKey; } public void setPrivateKey(String privateKey) { this.privateKey = privateKey; } } }
(2)参数绑定后、校验前的逻辑插入
实际上只要跟踪调试一下,就可以发现关键在于ExtendedServletRequestDataBinder
,覆盖两个方法即可:
import javax.servlet.ServletRequest; import org.springframework.beans.MutablePropertyValues; import org.springframework.validation.AbstractPropertyBindingResult; import org.springframework.web.servlet.mvc.method.annotation.ExtendedServletRequestDataBinder; /* package */ class RelaxedServletRequestDataBinder extends ExtendedServletRequestDataBinder { public RelaxedServletRequestDataBinder(Object target) { super(target); } public RelaxedServletRequestDataBinder(Object target, String objectName) { super(target, objectName); } /** * 添加属性值提供器的相关处理 */ @Override protected void addBindValues(MutablePropertyValues mpvs, ServletRequest request) { super.addBindValues(mpvs, request); PropertyValuesProviders.addBindValues(mpvs, request, getTarget(), getObjectName()); } /** * 常规绑定之后的处理 */ @Override public void bind(ServletRequest request) { super.bind(request); PropertyValuesProviders.afterBindValues(getPropertyAccessor(), request, getTarget(), getObjectName()); } @Override protected AbstractPropertyBindingResult createBeanPropertyBindingResult() { return new RelaxedBeanPropertyBindingResult(getTarget(), getObjectName(), isAutoGrowNestedPaths(), getAutoGrowCollectionLimit()); } }
这个类只是提供一个入口,真正的逻辑委托给了PropertyValuesProviders
处理。在我的实现中,这里提取了一个接口:
import javax.servlet.ServletRequest; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.PropertyAccessor; public interface IPropertyValuesProvider {//感谢JDK8中的默认实现 default void addBindValues(MutablePropertyValues mpvs, ServletRequest request, Object target, String name) {}; default void afterBindValues(PropertyAccessor accessor, ServletRequest request, Object target, String name) {}; }
import java.util.List; import javax.servlet.ServletRequest; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.PropertyAccessor; import org.springframework.beans.factory.annotation.Autowired; public class PropertyValuesProviders { private static List<IPropertyValuesProvider> providers; @Autowired(required = false) public void setProviders(List<IPropertyValuesProvider> providers) { if (providers != null) { PropertyValuesProviders.providers = providers; } } public static void addBindValues(MutablePropertyValues mpvs, ServletRequest request, Object target, String name) { if (null != providers) { for (IPropertyValuesProvider provider : providers) { provider.addBindValues(mpvs, request, target, name); } } } public static void afterBindValues(PropertyAccessor accessor, ServletRequest request, Object target, String name) { if (null != providers) { for (IPropertyValuesProvider provider : providers) { provider.afterBindValues(accessor, request, target, name); } } } }
然后添加RSA解密的实现:
import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import javax.servlet.ServletRequest; import org.springframework.beans.PropertyAccessor; import com.autumn.platform.core.crypto.RSAUtils; import com.autumn.platform.core.crypto.RSAUtils.RsaKeyPair; import com.autumn.platform.core.logger.Logs; import com.autumn.platform.web.annotation.RsaDecrypt; public abstract class AbstractRsaDecryptPropertyValuesProvider implements IPropertyValuesProvider { private Map<Class<?>, List<Field>> cache = new ConcurrentHashMap<Class<?>, List<Field>>(); @Override public void afterBindValues(PropertyAccessor accessor, ServletRequest request, Object target, String name) { RsaKeyPair pair = this.getRsaKeyPair(accessor, request, target, name); if (null == pair) { return; } for (Class<?> cls = target.getClass(); !cls.equals(Object.class); cls = cls.getSuperclass()) { List<Field> fields = resolveFields(cls); if (null != fields && !fields.isEmpty()) { for (Field field : fields) { setDecryptValue(target, pair, field); } } } } abstract protected RsaKeyPair getRsaKeyPair(PropertyAccessor accessor, ServletRequest request, Object target, String name); private void setDecryptValue(Object target, RsaKeyPair pair, Field field) { try { Object value = field.get(target); if (value instanceof String) { String text = RSAUtils.decryptByPrivateKey((String) value, pair.getPrivateKey()); field.set(target, text); } } catch (Exception e) { Logs.error("解密" + field + "的值时出现异常", e); } } private List<Field> resolveFields(Class<?> cls) { List<Field> fieldList = cache.get(cls); if (null == fieldList) { fieldList = new ArrayList<Field>(); Field[] fields = cls.getDeclaredFields(); if (null != fields) { for (Field field : fields) { if (field.isAnnotationPresent(RsaDecrypt.class)) { if (!field.isAccessible()) { field.setAccessible(true); } fieldList.add(field); } } } if (fieldList.isEmpty()) { fieldList = Collections.emptyList(); } cache.put(cls, fieldList); } return fieldList; } }
其中@RsaDecrypt
注解是自定义的,用于表示这个表单字段在接受参数时,需要先使用RSA解密。