
Token Renewal in Spring Security OAuth2

程序员文章站 2022-03-26 16:45:17

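Requirement: if the user performs any operation within the specified period, extend the token's validity; otherwise, let the token expire automatically when its time runs out.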
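The renewal rule itself can be sketched without Spring or Redis. In this minimal example, the class `SlidingExpirationDemo`, the method `renewIfActive`, and the two-day window are illustrative assumptions (the window mirrors the `3600 * 24 * 2 * 1000L` used in the store below). It only models the decision: a request on a still-valid token pushes its expiration forward, and an expired token is left alone.

```java
import java.time.Duration;
import java.time.Instant;

class SlidingExpirationDemo {
    // Hypothetical renewal window; the article's store extends tokens by two days
    static final Duration WINDOW = Duration.ofDays(2);

    /**
     * Returns the expiration a token should have after being used at {@code now}:
     * still valid -> now + WINDOW; already expired -> unchanged, so it stays expired.
     */
    static Instant renewIfActive(Instant expiration, Instant now) {
        if (now.isBefore(expiration)) {
            return now.plus(WINDOW);
        }
        return expiration;
    }

    public static void main(String[] args) {
        Instant now = Instant.parse("2020-11-27T10:00:00Z");
        // Token with one hour left: renewed to two days from now
        System.out.println("active:  " + renewIfActive(now.plus(Duration.ofHours(1)), now));
        // Token that expired an hour ago: not renewed
        System.out.println("expired: " + renewIfActive(now.minus(Duration.ofHours(1)), now));
    }
}
```

Because the new expiration is computed from the time of the request rather than from the old expiration, an active user's token never runs out, while an idle token simply expires.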

I store the token data in Redis, so I need to override RedisTokenStore. Here is the implementation:

import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.store.redis.RedisTokenStore;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Slf4j
public class MyRedisTokenStore extends RedisTokenStore {
	private static final Class<?> superClazz = RedisTokenStore.class;

	// Copies of the key prefixes used by RedisTokenStore (they are private there)
	private static final class KEYS{
		private static final String ACCESS = "access:";
		private static final String AUTH_TO_ACCESS = "auth_to_access:";
		private static final String AUTH = "auth:";
		private static final String ACCESS_TO_REFRESH = "access_to_refresh:";
		private static final String REFRESH_TO_ACCESS = "refresh_to_access:";
		private static final String CLIENT_ID_TO_ACCESS = "client_id_to_access:";
		private static final String UNAME_TO_ACCESS = "uname_to_access:";
	}

	// Caches for reflected methods (shared by all instances) and for field values
	// (per instance, because the values belong to this store). ConcurrentHashMap is used
	// because readAuthentication runs concurrently for parallel requests.
	private final Map<String, Object> fieldsValCache = new ConcurrentHashMap<>();
	private static final Map<String, Method> methodsCache = new ConcurrentHashMap<>();

	public MyRedisTokenStore(RedisConnectionFactory connectionFactory) {
		super(connectionFactory);
	}

	@Override
	public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
		OAuth2Authentication result = readAuthentication(token.getValue());

		// Only renew when the authentication still exists and the token's expiration can be changed in place
		if (result != null && token instanceof DefaultOAuth2AccessToken) {
			log.debug("Entering token renewal, token---{}, result---{}", token.getValue(), result);

			DefaultOAuth2AccessToken oAuth2AccessToken = (DefaultOAuth2AccessToken) token;

			// getName() also works for client-only (client_credentials) tokens,
			// whose principal is the client id String rather than a User
			String userName = result.getName();
			String clientId = result.getOAuth2Request().getClientId();

			log.debug("userName---{},clientId---{}", userName, clientId);

			// Remaining lifetime (seconds) of the token used by this request
			int expiresIn = oAuth2AccessToken.getExpiresIn();

			// The token has not expired yet: reset its expiration to 2 days from now
			if (expiresIn > 0) {
				oAuth2AccessToken.setExpiration(new Date(System.currentTimeMillis() + (3600 * 24 * 2 * 1000L)));
			}

			log.debug("userName---{},clientId---{},expiresIn---{}", userName, clientId, expiresIn);

			// Renew. Some posts online call storeAccessToken() here, which is problematic:
			// tokenRenewal() is essentially storeAccessToken(), except that it removes the old entries
			// from the sets before adding the new ones. Otherwise the entries in CLIENT_ID_TO_ACCESS and
			// UNAME_TO_ACCESS keep accumulating, which leads to Redis big keys.
			tokenRenewal(token, result);
		}
		return result;
	}

	/**
	 * Token renewal. Works like RedisTokenStore.storeAccessToken(), except that the previously
	 * stored serialized token is removed from the username and client-id sets before the
	 * renewed one is added.
	 */
	private void tokenRenewal(OAuth2AccessToken token, OAuth2Authentication authentication){
		AuthenticationKeyGenerator autKeyGenerator = getSuperFieldVal("authenticationKeyGenerator");
		Method getApprovalKey = getSuperMethod("getApprovalKey", OAuth2Authentication.class);
		Method serializeKey = getSuperMethod("serializeKey", String.class);
		// RedisTokenStore has two overloads: serialize(Object) for tokens and authentications,
		// and serialize(String) for plain token values. With the default serialization strategy
		// they produce different bytes, so each call must use the same overload as the parent class.
		Method serialize = getSuperMethod("serialize", Object.class);
		Method serializeString = getSuperMethod("serialize", String.class);
		Method getConnection = getSuperMethod("getConnection");

		String ataKey = KEYS.AUTH_TO_ACCESS + autKeyGenerator.extractKey(authentication);
		String utaKey = KEYS.UNAME_TO_ACCESS + invokeSuperMethod(getApprovalKey, authentication);
		String citaKey = KEYS.CLIENT_ID_TO_ACCESS + authentication.getOAuth2Request().getClientId();

		byte[] accessKey = invokeSuperMethod(serializeKey, KEYS.ACCESS + token.getValue());
		byte[] authKey = invokeSuperMethod(serializeKey, KEYS.AUTH + token.getValue());
		byte[] authToAccessKey = invokeSuperMethod(serializeKey, ataKey);
		byte[] approvalKey = invokeSuperMethod(serializeKey, utaKey);
		byte[] clientId = invokeSuperMethod(serializeKey, citaKey);

		byte[] serializedAuth = invokeSuperMethod(serialize, authentication);
		byte[] serializedAccessToken = invokeSuperMethod(serialize, token);
		Boolean springRedis_2_0 = getSuperFieldVal("springDataRedis_2_0");
		Method redisConnSet_2_0 = getSuperFieldVal("redisConnectionSet_2_0");

		RedisConnection conn = invokeSuperMethod(getConnection);
		try {
			// Read the previously stored serialized token before opening the pipeline
			// (inside a pipeline, get() does not return the value)
			byte[] access = conn.get(accessKey);

			conn.openPipeline();
			if (Boolean.TRUE.equals(springRedis_2_0)) {
				try {
					redisConnSet_2_0.invoke(conn, accessKey, serializedAccessToken);
					redisConnSet_2_0.invoke(conn, authKey, serializedAuth);
					redisConnSet_2_0.invoke(conn, authToAccessKey, serializedAccessToken);
				} catch (Exception ex) {
					throw new RuntimeException(ex);
				}
			} else {
				conn.set(accessKey, serializedAccessToken);
				conn.set(authKey, serializedAuth);
				conn.set(authToAccessKey, serializedAccessToken);
			}

			if (!authentication.isClientOnly()) {
				// Remove the old serialized token first, then add the renewed one
				if (access != null) {
					conn.sRem(approvalKey, access);
				}
				conn.sAdd(approvalKey, serializedAccessToken);
			}
			// Same for the client-id set: remove the old entry first
			if (access != null) {
				conn.sRem(clientId, access);
			}
			conn.sAdd(clientId, serializedAccessToken);

			if (token.getExpiration() != null) {
				int seconds = token.getExpiresIn();
				conn.expire(accessKey, seconds);
				conn.expire(authKey, seconds);
				conn.expire(authToAccessKey, seconds);
				conn.expire(clientId, seconds);
				conn.expire(approvalKey, seconds);
			}
			OAuth2RefreshToken refreshToken = token.getRefreshToken();
			if (refreshToken != null && refreshToken.getValue() != null) {
				String rtaKey = KEYS.REFRESH_TO_ACCESS + refreshToken.getValue();
				String atrKey = KEYS.ACCESS_TO_REFRESH + token.getValue();
				byte[] refreshToAccessKey = invokeSuperMethod(serializeKey, rtaKey);
				byte[] accessToRefreshKey = invokeSuperMethod(serializeKey, atrKey);
				// Plain token values use serialize(String), as in storeAccessToken()
				byte[] auth = invokeSuperMethod(serializeString, token.getValue());
				byte[] refresh = invokeSuperMethod(serializeString, refreshToken.getValue());

				if (Boolean.TRUE.equals(springRedis_2_0)) {
					try {
						redisConnSet_2_0.invoke(conn, refreshToAccessKey, auth);
						redisConnSet_2_0.invoke(conn, accessToRefreshKey, refresh);
					} catch (Exception ex) {
						throw new RuntimeException(ex);
					}
				} else {
					conn.set(refreshToAccessKey, auth);
					conn.set(accessToRefreshKey, refresh);
				}

				if (refreshToken instanceof ExpiringOAuth2RefreshToken) {
					ExpiringOAuth2RefreshToken expiringRefreshToken = (ExpiringOAuth2RefreshToken) refreshToken;
					Date expiration = expiringRefreshToken.getExpiration();
					if (expiration != null) {
						int seconds = Long.valueOf((expiration.getTime() - System.currentTimeMillis()) / 1000L).intValue();
						conn.expire(refreshToAccessKey, seconds);
						conn.expire(accessToRefreshKey, seconds);
					}
				}
			}
			conn.closePipeline();
		} finally {
			conn.close();
		}
	}

	// The methods below help access private fields and methods of the parent class

	/**
	 * Gets a private method of the parent class.
	 * @param methodName method name
	 * @param params parameter types of the method
	 * @return the method, made accessible
	 */
	private Method getSuperMethod(String methodName, Class<?>... params){
		// Key by name and parameter types so overloads such as serialize(Object)
		// and serialize(String) are cached separately
		StringBuilder cacheKey = new StringBuilder(methodName);
		for (Class<?> param : params) {
			cacheKey.append(',').append(param.getName());
		}
		Method method = methodsCache.get(cacheKey.toString());
		if (method == null) {
			try {
				method = superClazz.getDeclaredMethod(methodName, params);
			} catch (NoSuchMethodException e) {
				throw new IllegalStateException("RedisTokenStore has no method " + cacheKey, e);
			}
			method.setAccessible(true);
			methodsCache.put(cacheKey.toString(), method);
		}
		return method;
	}

	/**
	 * Invokes a private method of the parent class.
	 * @param method the method
	 * @param args arguments
	 * @param <T> expected return type
	 * @return the result of the call
	 */
	@SuppressWarnings("unchecked")
	private <T> T invokeSuperMethod(Method method, Object... args){
		try{
			return (T) method.invoke(this, args);
		}catch (ReflectiveOperationException e){
			throw new IllegalStateException("Failed to invoke RedisTokenStore#" + method.getName(), e);
		}
	}

	/**
	 * Gets the value of a private field of the parent class.
	 * @param fieldName field name
	 * @param <E> expected field type
	 * @return the field value
	 */
	@SuppressWarnings("unchecked")
	private <E> E getSuperFieldVal(String fieldName){
		Object result = fieldsValCache.get(fieldName);
		if (result == null) {
			try {
				Field field = superClazz.getDeclaredField(fieldName);
				field.setAccessible(true);
				result = field.get(this);
			} catch (ReflectiveOperationException e) {
				throw new IllegalStateException("Failed to read RedisTokenStore#" + fieldName, e);
			}
			// Only cache non-null values; a null (e.g. redisConnectionSet_2_0 on
			// Spring Data Redis before 2.0) is simply read again next time
			if (result != null) {
				fieldsValCache.put(fieldName, result);
			}
		}
		return (E) result;
	}
}

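Compared with storeAccessToken, tokenRenewal essentially adds the following three lines, which read the old serialized token and remove it from the username and client-id sets before the renewed token is added: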

byte[] access = conn.get(accessKey);
conn.sRem(approvalKey, access);
conn.sRem(clientId, access);
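Why the `sRem` matters can be shown without Redis. This minimal sketch models a Redis set such as `client_id_to_access:<clientId>` as an in-memory `Set<String>`, and models a serialized token as a string that includes its expiration. That string format, the class `SetRenewalDemo`, and its helper methods are hypothetical stand-ins: the real store holds serialized bytes, which likewise change whenever the expiration changes. Renewing with only `SADD` leaves every older copy in the set, while removing the previous value first keeps a single entry per token.

```java
import java.util.HashSet;
import java.util.Set;

class SetRenewalDemo {
    // Hypothetical stand-in for the serialized token: changes whenever the expiration changes
    static String serialize(String tokenValue, long expiresAt) {
        return tokenValue + "@" + expiresAt;
    }

    // Like storeAccessToken: SADD only, so the previous serialized copy stays in the set
    static void storeOnly(Set<String> set, String current) {
        set.add(current);
    }

    // Like tokenRenewal: SREM the previous serialized copy (if any), then SADD the new one
    static void renew(Set<String> set, String previous, String current) {
        if (previous != null) {
            set.remove(previous);
        }
        set.add(current);
    }

    /** Renews the same token {@code times} times and returns the resulting set size. */
    static int simulate(int times, boolean removeOld) {
        Set<String> clientIdToAccess = new HashSet<>();
        String previous = null;
        for (int i = 0; i < times; i++) {
            // Each renewal moves the expiration forward, so the serialized value differs
            String current = serialize("token-1", 1000L + i);
            if (removeOld) {
                renew(clientIdToAccess, previous, current);
            } else {
                storeOnly(clientIdToAccess, current);
            }
            previous = current;
        }
        return clientIdToAccess.size();
    }

    public static void main(String[] args) {
        System.out.println("SADD only:   " + simulate(100, false) + " entries");
        System.out.println("SREM + SADD: " + simulate(100, true) + " entries");
    }
}
```

With SADD only, the set grows by one entry per renewed request for as long as the user stays active, which is how the big key builds up. With SREM first, it stays at one entry.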

Original article: https://blog.csdn.net/xhom_w/article/details/110229681