欢迎您访问程序员文章站,本站旨在为大家提供、分享程序员计算机编程知识!
您现在的位置是: 首页

基于Mina2的Websocket实现

程序员文章站 2022-06-26 11:31:01
...

    在网上找了很多关于websocket协议的资料。我发现大部分资料或实现,要么记录得不完整,要么只给出了最基本的实现。于是,我花了一周的业余时间写了一个相对完整的实现。

    首先是解码器部分:

    

/**
 * MINA protocol decoder for the server side of a WebSocket connection.
 *
 * <p>While the session is in the {@code Handshake} state the incoming bytes are
 * parsed as the HTTP upgrade request and the "101 Switching Protocols" response
 * is written back. Once the session is {@code Connected}, incoming bytes are
 * parsed as WebSocket data frames; fragments are accumulated per session and
 * the complete message is emitted as a {@code String} (text) or {@code byte[]}
 * (binary).
 *
 * <p>{@code doDecode} returns {@code false} only when more bytes are needed
 * (nothing is consumed in that case) and {@code true} whenever bytes were
 * consumed, so that frames already present in the buffer keep being decoded.
 */
public class WSDecoder extends CumulativeProtocolDecoder {
	private final static String REQUEST_CONTEXT_KEY = "__REQUEST_DATA_CONTEXT";
	private final static String END_TAG = "\r\n";

	/** Largest amount of buffered handshake data tolerated before the request is rejected. */
	private final static int MAX_HANDSHAKE_LENGTH = 65536;

	private  enum FrameType {
		Text,Binary,Control;
	}

	/**
	 * Per-session accumulator for a (possibly fragmented) message. It remembers
	 * the frame type announced by the first fragment, because continuation
	 * frames carry opcode 0 and no type information.
	 */
	private static class RequestDataContext {

		private final IoBuffer buffer;
		private final CharsetDecoder decoder;
		private FrameType frameType;

		/**
		 * @param charset name of the charset used to decode text messages
		 */
		RequestDataContext(String charset) {
			this.buffer = IoBuffer.allocate(512).setAutoExpand(true);
			// Honour the requested charset; each context owns its decoder
			// because CharsetDecoder instances are not thread-safe.
			this.decoder = Charset.forName(charset).newDecoder();
		}

		public FrameType getFrameType() {
			return this.frameType;
		}

		/**
		 * Returns the accumulated payload decoded as text.
		 *
		 * @throws CharacterCodingException if the payload is not valid in the configured charset
		 */
		public String getDataAsString() throws CharacterCodingException {
			buffer.flip();
			return buffer.getString(decoder);
		}

		/** Returns the accumulated payload as a byte array. */
		public byte[] getDataAsArray() {
			buffer.flip();

			byte[] data = new byte[buffer.remaining()];
			buffer.get(data);

			return data;
		}

		void append(byte[] data) {
			this.buffer.put(data);
		}

		void setFrameType(FrameType frameType) {
			this.frameType = frameType;
		}
	}

	private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;

	public void setByteOrder(ByteOrder byteOrder) {
		this.byteOrder = byteOrder;
	}

	private CharsetDecoder charsetDecoder = Charset.forName("utf-8")
			.newDecoder();

	public void setCharsetDecoder(CharsetDecoder charsetDecoder) {
		this.charsetDecoder = charsetDecoder;
	}

	/**
	 * Dispatches on the session state: handshake parsing or frame parsing.
	 *
	 * @return {@code true} if bytes were consumed, {@code false} if more data
	 *         is required before anything can be decoded
	 */
	@Override
	protected boolean doDecode(IoSession session, IoBuffer in,
			ProtocolDecoderOutput decoderOutput)
			throws CharacterCodingException, NoSuchAlgorithmException {

		if (!in.hasRemaining())
			return false;

		WSSessionState state = getSessionState(session);

		switch (state) {
			case Handshake:
				return doHandshake(session, in);

			case Connected:
				return decodeFrame(session, in, decoderOutput);

			default:
				closeConnection(session, in);
				return true;
		}
	}

	/**
	 * Decodes one WebSocket frame from {@code in}.
	 *
	 * @return {@code false} (with the buffer position restored) when the frame
	 *         is not completely buffered yet, {@code true} otherwise
	 */
	private boolean decodeFrame(IoSession session, IoBuffer in,
			ProtocolDecoderOutput decoderOutput)
			throws CharacterCodingException {

		if (in.remaining() < 2)
			return false;

		// Remember where the frame starts so an incomplete frame can be retried.
		int frameStart = in.position();
		in.order(this.byteOrder);

		byte fstByte = in.get();
		int opCode = fstByte & 0xf;

		switch (opCode) {
			case 0x0:
			case 0x1:
			case 0x2:
				break;

			default:
				// Close (0x8), ping/pong (0x9/0xA) and reserved opcodes are not
				// processed: the connection is closed. Skipping a single byte
				// and carrying on would desynchronise the frame stream.
				closeConnection(session, in);
				return true;
		}

		boolean isFinalFrame = fstByte < 0;
		boolean isRsvColZero = (fstByte & 0x70) == 0;
		if (!isRsvColZero) {
			closeConnection(session, in);
			return true;
		}

		byte secByte = in.get();
		boolean isMasking = secByte < 0;
		int payload = secByte & 0x7f;

		// The extended length field may not have arrived yet.
		int extendedLengthBytes = payload == 126 ? 2 : (payload == 127 ? 8 : 0);
		if (in.remaining() < extendedLengthBytes) {
			in.position(frameStart);
			return false;
		}

		long declaredLength;
		if (payload == 126)
			declaredLength = in.getUnsignedShort();
		else if (payload == 127)
			declaredLength = in.getLong();
		else
			declaredLength = payload;

		// Reject lengths that cannot be represented as a Java array size
		// (negative, or overflowing once the 4 mask bytes are added).
		if (declaredLength < 0 || declaredLength > Integer.MAX_VALUE - 4) {
			closeConnection(session, in);
			return true;
		}
		int dataLength = (int) declaredLength;

		if (in.remaining() < (isMasking ? dataLength + 4 : dataLength)) {
			in.position(frameStart);
			return false;
		}

		byte[] data = new byte[dataLength];
		if (isMasking) {
			byte[] mask = new byte[4];
			in.get(mask);
			in.get(data);

			// Unmask the payload.
			for (int i = 0; i < data.length; i++)
				data[i] = (byte) (data[i] ^ mask[i % mask.length]);
		} else {
			in.get(data);
		}

		// Keep the frame type in a session-scoped object: for fragmented
		// messages only the first frame carries the type, so it has to be
		// remembered until the final fragment arrives.
		RequestDataContext context = (RequestDataContext) session
				.getAttribute(REQUEST_CONTEXT_KEY);
		if (context == null) {
			context = new RequestDataContext(charsetDecoder.charset()
					.name());
			context.setFrameType((opCode == 0x1) ? FrameType.Text
					: FrameType.Binary);

			session.setAttribute(REQUEST_CONTEXT_KEY, context);
		}

		context.append(data);

		if (isFinalFrame) {
			context = (RequestDataContext) session.removeAttribute(REQUEST_CONTEXT_KEY);

			if (context.getFrameType() == FrameType.Text)
				decoderOutput.write(context.getDataAsString());
			else
				decoderOutput.write(context.getDataAsArray());
		}

		// The frame was consumed either way; returning true lets any further
		// frames already in the buffer be decoded immediately.
		return true;
	}

	/**
	 * Parses the HTTP upgrade request, validates it and writes the handshake
	 * response to the session.
	 *
	 * @return {@code false} if the header block is not complete yet (nothing
	 *         is consumed), {@code true} once the request has been processed
	 * @throws WSException if the request is not a valid WebSocket upgrade request
	 */
	private boolean doHandshake(IoSession session, IoBuffer in)
			throws CharacterCodingException, NoSuchAlgorithmException {
		if (!hasCompleteHeader(in)) {
			if (in.remaining() > MAX_HANDSHAKE_LENGTH) {
				// Drop the oversized data so it is not buffered forever.
				in.position(in.limit());
				throw new WSException("Invalid websocket request!");
			}
			return false;
		}

		String handshakeMessage = in.getString(newHandshakeDecoder());
		String[] msgColumns = splitHandshakeMessage(handshakeMessage);
		if (msgColumns.length == 0)
			throw new WSException("Invalid websocket request!");

		String requestURI = msgColumns[0];
		String httpVersion = requestURI.substring(
				requestURI.lastIndexOf("/") + 1, requestURI.length());

		String upgradeCol = getMessageColumnValue(msgColumns, "Upgrade:");
		String connectionCol = getMessageColumnValue(msgColumns, "Connection:");
		String secWsProtocolCol = getMessageColumnValue(msgColumns,
				"Sec-WebSocket-Protocol:");
		String secWskeyCol = getMessageColumnValue(msgColumns,
				"Sec-WebSocket-Key:");
		String wsVersionCol = getMessageColumnValue(msgColumns,
				"Sec-WebSocket-Version:");

		// Validate the mandatory fields. Any field that does not satisfy its
		// condition makes the handshake fail.
		boolean hasWebsocket = contains(upgradeCol, "websocket");
		boolean hasUpgrade = contains(connectionCol, "upgrade");
		boolean isGetMethod = "GET".equalsIgnoreCase(subString(requestURI, 1,
				" "));

		boolean isSecWsKeyNull = secWskeyCol == null || secWskeyCol.isEmpty();
		boolean isValidVersion = "13".equals(wsVersionCol);
		boolean isValidHttpVer = isSupportedHttpVersion(httpVersion);

		if (!hasWebsocket || !hasUpgrade || !isGetMethod)
			throw new WSException("Invalid websocket request!");

		if (isSecWsKeyNull || !isValidVersion || !isValidHttpVer)
			throw new WSException("Invalid websocket request!");

		String secWsAccept = getSecWebsocketAccept(secWskeyCol);
		String response = getResponseData(upgradeCol, connectionCol,
				secWsAccept, secWsProtocolCol);

		initRequestContext(session, msgColumns);

		session.write(response);
		return true;
	}

	/** Tells whether the unread bytes of {@code in} contain the blank line ending an HTTP header block. */
	private static boolean hasCompleteHeader(IoBuffer in) {
		for (int i = in.position(), last = in.limit() - 4; i <= last; i++) {
			if (in.get(i) == '\r' && in.get(i + 1) == '\n'
					&& in.get(i + 2) == '\r' && in.get(i + 3) == '\n')
				return true;
		}
		return false;
	}

	/**
	 * Creates a decoder configured like {@link #charsetDecoder} for a single
	 * handshake, so the shared instance is never used by two sessions at once.
	 */
	private CharsetDecoder newHandshakeDecoder() {
		return charsetDecoder.charset().newDecoder()
				.onMalformedInput(charsetDecoder.malformedInputAction())
				.onUnmappableCharacter(charsetDecoder.unmappableCharacterAction())
				.replaceWith(charsetDecoder.replacement());
	}

	/** Returns whether {@code httpVersion} parses as a number that is at least 1.1. */
	private static boolean isSupportedHttpVersion(String httpVersion) {
		try {
			return Float.parseFloat(httpVersion) >= 1.1F;
		} catch (NumberFormatException e) {
			// A request line without a numeric version is simply not valid.
			return false;
		}
	}

	/** Splits the handshake message into its non-empty lines (CR and LF both act as separators). */
	private String[] splitHandshakeMessage(String handshakeMessage) {
		StringTokenizer st = new StringTokenizer(handshakeMessage, END_TAG);
		String[] result = new String[st.countTokens()];

		for (int i = 0; st.hasMoreTokens(); i++)
			result[i] = st.nextToken();

		return result;
	}

	/** Case-insensitive containment test; {@code target} must be lower case. Null or empty {@code src} yields false. */
	private boolean contains(String src, String target) {
		if (src == null || src.isEmpty())
			return false;
		else
			return src.toLowerCase().contains(target);
	}

	/**
	 * Computes the Sec-WebSocket-Accept value: Base64 of the SHA-1 digest of
	 * the client key concatenated with the fixed WebSocket GUID.
	 */
	private String getSecWebsocketAccept(String secWebsocketkey)
			throws NoSuchAlgorithmException {

		StringBuilder srcBuilder = new StringBuilder();
		srcBuilder.append(secWebsocketkey);
		srcBuilder.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");

		MessageDigest md = MessageDigest.getInstance("SHA-1");

		md.update(srcBuilder.toString().getBytes(charsetDecoder.charset()));
		byte[] ciphertext = md.digest();

		String result = new String(Base64.encodeBase64(ciphertext),
				charsetDecoder.charset());
		return result;
	}

	/** Builds the "101 Switching Protocols" response; the sub-protocol header is added only when one was requested. */
	private String getResponseData(String upgrade, String connection,
			String secWsAccept, String secWsProtocol) {

		StringBuilder result = new StringBuilder();

		result.append("HTTP/1.1 101 Switching Protocols\r\n");
		result.append("Upgrade:").append(upgrade).append(END_TAG);
		result.append("Connection:").append(connection).append(END_TAG);
		result.append("Sec-WebSocket-Accept:").append(secWsAccept)
				.append(END_TAG);

		if (secWsProtocol != null && !"".equals(secWsProtocol))
			result.append("Sec-WebSocket-Protocol:").append(secWsProtocol)
					.append(END_TAG);

		result.append(END_TAG);

		return result.toString();
	}

	/**
	 * Closes the session and discards the unread bytes of {@code buffer}. The
	 * buffer belongs to the framework, so it is drained rather than freed;
	 * draining also guarantees that the decoder made progress.
	 */
	private void closeConnection(IoSession session, IoBuffer buffer) {
		buffer.position(buffer.limit());
		session.close(true);
	}

	/** Writes {@code errorMsg} to the peer and closes the session once the write has completed. */
	@SuppressWarnings("unused")
	private void closeConnection(IoSession session, String errorMsg) {
		session.write(errorMsg).addListener(
				new IoFutureListener<WriteFuture>() {
					@Override
					public void operationComplete(WriteFuture future) {
						future.getSession().close(true);
					}
				});
	}

	/** Stores the handshake request lines on the session under "__SESSION_CONTEXT". */
	private void initRequestContext(IoSession session, String[] data) {
		session.setAttribute("__SESSION_CONTEXT", data);
	}
}

     

    下面是编码器部分:

    

/**
 * MINA protocol encoder for the server side of a WebSocket connection.
 *
 * <p>In the {@code Handshake} state the message (the HTTP upgrade response
 * string) is written as-is and the session is switched to {@code Connected}.
 * In the {@code Connected} state a {@code String} is sent as a text message
 * and anything else is cast to {@code byte[]} and sent as a binary message;
 * messages larger than the page size are split into several frames.
 */
public class WSEncoder extends ProtocolEncoderAdapter {

	// Stored for callers of setByteOrder; encode() does not apply it to the
	// buffers it creates.
	private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;

	public void setByteOrder(ByteOrder byteOrder) {
		this.byteOrder = byteOrder;
	}

	private CharsetEncoder charsetEncoder = Charset.forName("utf-8")
			.newEncoder();

	public void setCharsetEncoder(CharsetEncoder charsetEncoder) {
		this.charsetEncoder = charsetEncoder;
	}

	// Maximum size in bytes of one frame, header included.
	private int defaultPageSize = 65536;

	public void setDefaultPageSize(int defaultPageSize) {
		this.defaultPageSize = defaultPageSize;
	}

	// Outgoing data is not masked by default.
	private boolean isMasking = false;
	public void setIsMasking(boolean masking) {
		this.isMasking = masking;
	}

	// Source of masking keys; only used when masking is enabled.
	private final Random maskSource = new java.security.SecureRandom();

	/**
	 * Encodes {@code message} according to the session state.
	 *
	 * @param session       the session the message is written to
	 * @param message       the handshake response string, or a {@code String}
	 *                      / {@code byte[]} payload once connected
	 * @param encoderOutput receives one buffer per generated frame
	 */
	@Override
	public void encode(IoSession session, Object message, ProtocolEncoderOutput encoderOutput) throws CharacterCodingException {
		WSSessionState status = getSessionState(session);

		switch (status) {
			case Handshake:
				encodeHandshake(session, (String) message, encoderOutput);
				break;

			case Connected:
				if (!session.isConnected() || message == null)
					return;

				byte dataType = 1;

				// Convert every payload to a byte array before framing.
				byte[] data = null;
				if (message instanceof String)
					data = ((String) message).getBytes(charsetEncoder.charset());
				else {
					data = (byte[]) message;
					dataType = 2;
				}

				writeFrames(data, dataType, encoderOutput);
				break;

			default:
				session.close(true);
				break;
		}
	}

	/**
	 * Writes the handshake response and marks the session as connected. If the
	 * response cannot be encoded the session is closed and its state is left
	 * untouched.
	 */
	private void encodeHandshake(IoSession session, String response, ProtocolEncoderOutput encoderOutput) {
		IoBuffer buff = IoBuffer.allocate(1024).setAutoExpand(true);

		// Use a private encoder configured like the shared one: CharsetEncoder
		// instances are not safe for concurrent use.
		CharsetEncoder encoder = charsetEncoder.charset().newEncoder()
				.onMalformedInput(charsetEncoder.malformedInputAction())
				.onUnmappableCharacter(charsetEncoder.unmappableCharacterAction());

		try {
			buff.putString(response, encoder).flip();
		} catch (CharacterCodingException e) {
			session.close(true);
			return;
		}

		encoderOutput.write(buff);
		session.setAttribute(WSSessionState.ATTRIBUTE_KEY, WSSessionState.Connected);
	}

	/**
	 * Splits {@code data} into frames no larger than the page size and writes
	 * each one to {@code encoderOutput} in its own buffer. Only the first
	 * frame carries {@code dataType}; later frames use opcode 0.
	 */
	private void writeFrames(byte[] data, byte dataType, ProtocolEncoderOutput encoderOutput) {
		int maskLength = isMasking ? 4 : 0;
		int pageSize = this.defaultPageSize;

		int offset = 0;
		boolean isFirstFrame = true;
		boolean isFinalFrame;

		do {
			int remainLength = data.length - offset;

			// Payload size of this frame: everything that is left, unless the
			// frame would exceed the page size. At least one byte is sent per
			// frame so that a tiny page size cannot stall the loop.
			int frameLength = remainLength;
			int worstCaseHeader = headerLength(remainLength, maskLength);
			if (frameLength > pageSize - worstCaseHeader)
				frameLength = Math.min(remainLength, Math.max(1, pageSize - worstCaseHeader));

			isFinalFrame = frameLength == remainLength;

			// The length code describes this frame's own payload, so the
			// shortest possible length encoding is always used.
			int lengthCode = lengthCode(frameLength);

			// A fresh buffer per frame: buffers handed to encoderOutput must
			// not be modified afterwards.
			IoBuffer frame = IoBuffer.allocate(headerLength(frameLength, maskLength) + frameLength);

			// First byte: FIN flag plus, on the first frame only, the data type.
			frame.put((byte) ((isFinalFrame ? 0x80 : 0x0) | (isFirstFrame ? dataType : 0)));

			// Second byte: mask flag plus the length code.
			frame.put((byte) ((isMasking ? 0x80 : 0) | lengthCode));

			if (lengthCode == 126)
				frame.putUnsignedShort(frameLength);
			else if (lengthCode == 127)
				frame.putLong(frameLength);

			if (isMasking) {
				byte[] mask = new byte[4];
				maskSource.nextBytes(mask);
				frame.put(mask);

				// The mask index restarts at 0 for every frame.
				for (int i = 0; i < frameLength; i++)
					frame.put((byte) (data[offset + i] ^ mask[i % 4]));
			} else {
				frame.put(data, offset, frameLength);
			}

			frame.flip();
			encoderOutput.write(frame);

			offset += frameLength;
			isFirstFrame = false;
		} while (!isFinalFrame);
	}

	/** Returns the 7-bit length code for a payload: the length itself up to 125, else 126 (16-bit) or 127 (64-bit). */
	private static int lengthCode(int payloadLength) {
		if (payloadLength <= 125)
			return payloadLength;
		return payloadLength <= 65535 ? 126 : 127;
	}

	/** Returns the frame header size for a payload of the given length. */
	private static int headerLength(int payloadLength, int maskLength) {
		int code = lengthCode(payloadLength);
		int extended = code == 126 ? 2 : (code == 127 ? 8 : 0);
		return 2 + extended + maskLength;
	}
}

     

    下面是工具类实现:

    

/**
 * Static helpers shared by the WebSocket encoder and decoder: session state
 * bookkeeping and small string utilities for parsing the handshake request.
 */
final class WSToolKit {
	private WSToolKit() {
	}

	/** Protocol phase of a session; stored on the session under {@link #ATTRIBUTE_KEY}. */
	static enum WSSessionState {
		Handshake, Connected;
		public final static String ATTRIBUTE_KEY = "__SESSION_STATE";
	}

	/** Returns {@code t1}, or {@code t2} when {@code t1} is null. */
	static <T> T nvl(T t1, T t2) {
		return t1 == null ? t2 : t1;
	}

	/**
	 * Returns the state stored on the session. A session without a stored
	 * state is initialised to {@code Handshake}.
	 */
	static WSSessionState getSessionState (IoSession session) {
		WSSessionState state = (WSSessionState) session
				.getAttribute(WSSessionState.ATTRIBUTE_KEY);

		if (state == null) {
			state = WSSessionState.Handshake;
			session.setAttribute(WSSessionState.ATTRIBUTE_KEY, state);
		}

		return state;
	}

	/**
	 * Finds the first header line starting with {@code headerTag} and returns
	 * the trimmed text after the tag. The tag is matched case-insensitively
	 * because HTTP header field names are case-insensitive.
	 *
	 * @param headers   header lines; null entries are skipped
	 * @param headerTag header name including the trailing colon, e.g. "Upgrade:"
	 * @return the header value, or null when no line matches
	 */
	static String getMessageColumnValue(String[] headers, String headerTag) {
		for (String header : headers) {
			if (header != null
					&& header.regionMatches(true, 0, headerTag, 0, headerTag.length()))
				return header.substring(headerTag.length()).trim();
		}

		return null;
	}

	/**
	 * Returns the {@code order}-th field (1-based) of {@code src} when split
	 * on {@code flag}. An {@code order} below 1 is treated as 1.
	 *
	 * @return the field text, or "" when {@code src} has fewer than
	 *         {@code order} fields
	 */
	static String subString(String src, int order, String flag) {
		int start = 0;

		// Skip the fields in front of the requested one.
		for (int field = 1; field < order; field++) {
			int hit = src.indexOf(flag, start);
			if (hit == -1)
				return "";
			// Step over the whole delimiter, not just its first character.
			start = hit + flag.length();
		}

		int end = src.indexOf(flag, start);
		return end == -1 ? src.substring(start) : src.substring(start, end);
	}
}

     

    如果大家想转载,请注明出处!谢谢。