# This file is part of NIT ( http://www.nitlanguage.org ).
#
-# Copyright 2014 Lucas Bajolet <r4pass@hotmail.com>
-#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import socket
import sha1
import base64
+import crypto
-intrude import standard::stream
-
-# Websocket compatible listener
+# Websocket compatible server
#
# Produces Websocket client-server connections
+#
+# Wraps a listening `TCPServer`; each call to `accept` returns a
+# `WebsocketConnection` to one client.
-class WebSocketListener
- super Socket
+class WebsocketServer
- # Socket listening to connections on a defined port
+ # Socket listening for incoming Websocket connections
var listener: TCPServer
- # Creates a new Websocket server listening on given port with `max_clients` slots available
- init(port: Int, max_clients: Int)
+ # Has `close` been called on `self`?
+ var closed = false
+
+ # Create a new Websocket server listening on `port`
+ #
+ # Opens a `TCPServer` on `port` and calls `listen max_clients` on it,
+ # so the returned server is ready to `accept` connections.
+ init with_infos(port: Int, max_clients: Int)
do
- listener = new TCPServer(port)
+ var listener = new TCPServer(port)
listener.listen max_clients
+ # Chain to the default `init` with the freshly opened socket
+ init(listener)
end
# Accepts an incoming connection
var client = listener.accept
+ # `accept` can return null; the assert aborts the program in that case
assert client != null
- return new WebsocketConnection(listener.port, "", client)
+ # The opening handshake with the client runs when the connection is created
+ return new WebsocketConnection(client)
end
- # Stop listening for incoming connections
+ # Close the server and the socket below it
+ #
+ # Closes `listener` and sets `closed`; connections previously returned
+ # by `accept` are not closed by this method.
fun close
do
listener.close
+ closed = true
end
end
#
# Can be used to communicate with a client
class WebsocketConnection
- super TCPStream
+ super DuplexProtocol
+ super PollableReader
+
+ # `origin` is the client's TCP connection
+ redef type STREAM: TCPStream
+
+ # Does the current frame have a mask?
+ private var has_mask = false
+
+ # Masking key of the current frame, XORed with incoming payload bytes
+ #
+ # All zeroes when the current frame is not masked.
+ private var mask = new CString(4)
+
+ # Index in `mask` of the key byte to apply to the next payload byte
+ #
+ # Must stay in `[0 .. 4[`: it indexes the 4-byte `mask`, and a payload
+ # always starts being unmasked with the first key byte.
+ private var mask_offset = 0
+
+ # Length, in bytes, of the current frame payload
+ #
+ # Starts at -1, like `frame_cursor`, so that the first read finds the
+ # current frame exhausted and parses a frame header.
+ private var frame_length = -1
+
+ # Number of payload bytes of the current frame already read
+ private var frame_cursor = -1
+
+ # Opcode of the current frame (0 = continuation, 1 = text, 2 = binary)
+ #
+ # -1 until a frame has been received.
+ var frame_type = -1
+
+ # Is `self` closed?
+ var closed = false
+ # Perform the server side of the opening handshake
+ #
+ # Parses the client's request with `parse_handshake` and writes the
+ # response built by `handshake_response` back on `origin`.
init do
- _buffer = new FlatBuffer
- _buffer_pos = 0
var headers = parse_handshake
var resp = handshake_response(headers)
- client.write(resp)
+ origin.write(resp)
end
- # Client connection to the server
- var client: TCPStream
-
# Disconnect from a client
- redef fun close
- do
- client.close
+ #
+ # Closes `origin` and sets `closed`; once `closed` is set, `raw_read_byte`
+ # and `raw_read_bytes` stop reading and return -1.
+ # NOTE(review): no Websocket close frame is sent to the client first.
+ redef fun close do
+ origin.close
+ closed = true
end
- # Parses the input handshake sent by the client
+ # Pong frame with an empty payload, sent in reply to a ping
+ #
+ # 0x8a = FIN bit + opcode 0xA (pong); 0x00 = unmasked, zero-length payload.
+ private fun pong_msg: Bytes do return once b"\x8a\x00"
+
+ # Parse the input handshake sent by the client
# See RFC 6455 for information
+ # Returns the request headers; `handshake_response` looks up `Sec-WebSocket-Key` in them.
private fun parse_handshake: Map[String,String]
do
return headmap
end
- # Generates the handshake
+ # Generate the handshake response for the client request headers `heads`
+ #
+ # Each entry of the response map becomes a `name value` line; lines are
+ # joined with CRLF and the header block ends with an empty line.
+ # `Sec-WebSocket-Accept` is the base64 of the SHA-1 of the client's
+ # `Sec-WebSocket-Key` followed by the GUID of RFC 6455, section 4.2.2.
+ # NOTE(review): a request without `Sec-WebSocket-Key` is not handled here;
+ # confirm how `Map::[]` behaves on a missing key.
private fun handshake_response(heads: Map[String,String]): String
do
var resp_map = new HashMap[String,String]
resp_map["Connection:"] = "Upgrade"
var key = heads["Sec-WebSocket-Key"]
+ # Fixed GUID defined by RFC 6455 for computing the accept key
key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
- key = key.sha1.encode_base64
+ key = key.sha1.encode_base64.to_s
resp_map["Sec-WebSocket-Accept:"] = key
var resp = resp_map.join("\r\n", " ")
resp += "\r\n\r\n"
return resp
end
- # Frames a text message to be sent to a client
- private fun frame_message(msg: String): String
+ # Frame a text message to be sent to a client
+ #
+ # Produces a single, final, unmasked text frame (RFC 6455, section 5.2).
+ # The payload length is the number of bytes of `msg`, encoded on 7 bits,
+ # or on 16 or 64 extra bits for longer messages.
+ private fun frame_message(msg: Text): Bytes
do
- var ans_buffer = new FlatBuffer
+ # Frame lengths count bytes, which differs from `msg.length` for non-ASCII text
+ var len = msg.byte_length
+ # At most 10 header bytes precede the payload
+ var ans_buffer = new Bytes.with_capacity(len + 10)
# Flag for final frame set to 1
# opcode set to 1 (for text)
- ans_buffer.add(129.ascii)
- if msg.length < 126 then
- ans_buffer.add(msg.length.ascii)
- end
- if msg.length >= 126 and msg.length <= 65535 then
- ans_buffer.add(126.ascii)
- ans_buffer.add(msg.length.rshift(8).ascii)
- ans_buffer.add(msg.length.ascii)
- end
- ans_buffer.append(msg)
- return ans_buffer.to_s
+ ans_buffer.add(129)
+ if len < 126 then
+ ans_buffer.add(len)
+ else if len <= 65535 then
+ # 126 announces a 16-bit big-endian length
+ ans_buffer.add(126)
+ ans_buffer.add((len >> 8) & 255)
+ ans_buffer.add(len & 255)
+ else
+ # 127 announces a 64-bit big-endian length
+ ans_buffer.add(127)
+ for i in [0 .. 8[ do
+ ans_buffer.add((len >> (8 * (7 - i))) & 255)
+ end
+ end
+ msg.append_to_bytes(ans_buffer)
+ return ans_buffer
end
- # Reads an HTTP frame
+ # Read an HTTP frame
+ #
+ # Appends each line read from `origin` to `buf`, followed by CRLF, until
+ # `buf` ends with an empty line, then returns the content of `buf`.
protected fun read_http_frame(buf: Buffer): String
do
- buf.append client.read_line
- buf.append("\r\n")
- if buf.has_suffix("\r\n\r\n") then return buf.to_s
- return read_http_frame(buf)
+ # Iterate instead of recursing once per line: the number of header
+ # lines is chosen by the client and must not drive the stack depth
+ loop
+ var ln = origin.read_line
+ buf.append ln
+ buf.append "\r\n"
+ if buf.has_suffix("\r\n\r\n") then return buf.to_s
+ end
end
- # Gets the message from the client, unpads it and reconstitutes the message
- private fun unpad_message do
- var fin = false
- while not fin do
- var fst_char = client.read_char
- var snd_char = client.read_char
- # First byte in msg is formatted this way :
- # |(fin - 1bit)|(RSV1 - 1bit)|(RSV2 - 1bit)|(RSV3 - 1bit)|(opcode - 4bits)
- # fin = Flag indicating if current frame is the last one
- # RSV1/2/3 = Extension flags, unsupported
- # Opcode values :
- # %x0 denotes a continuation frame
- # %x1 denotes a text frame
- # %x2 denotes a binary frame
- # %x3-7 are reserved for further non-control frames
- # %x8 denotes a connection close
- # %x9 denotes a ping
- # %xA denotes a pong
- # %xB-F are reserved for further control frames
- var fin_flag = fst_char.bin_and(128)
- if fin_flag != 0 then fin = true
- var opcode = fst_char.bin_and(15)
- if opcode == 9 then
- _buffer.add(138.ascii)
- _buffer.add('\0')
- client.write(_buffer.to_s)
- _buffer_pos += 2
- return
- end
- if opcode == 8 then
- self.client.close
- return
- end
- # Second byte is formatted this way :
- # |(mask - 1bit)|(payload length - 7 bits)
- # As specified, if the payload length is 126 or 127
- # The next 16 or 64 bits contain an extended payload length
- var mask_flag = snd_char.bin_and(128)
- var len = snd_char.bin_and(127)
- var payload_ext_len = 0
- if len == 126 then
- payload_ext_len = client.read_char.lshift(8)
- payload_ext_len += client.read_char
- else if len == 127 then
- # 64 bits for length are not supported,
- # only the last 32 will be interpreted as a Nit Integer
- for i in [0..4[ do client.read_char
- payload_ext_len = client.read_char.lshift(24)
- payload_ext_len += client.read_char.lshift(16)
- payload_ext_len += client.read_char.lshift(8)
- payload_ext_len += client.read_char
- end
- if mask_flag != 0 then
- if payload_ext_len != 0 then
- var msg = client.read(payload_ext_len+4)
- var mask = msg.substring(0,4)
- _buffer.append(unmask_message(mask, msg.substring(4, payload_ext_len)))
- else
- if len == 0 then
- return
- end
- var msg = client.read(len+4)
- var mask = msg.substring(0,4)
- _buffer.append(unmask_message(mask, msg.substring(4, len)))
- end
- end
- end
- end
+ # Read the header of the next frame and prepare `self` to read its payload
+ #
+ # Parses the 2-byte base header, the optional 16 or 64-bit extended
+ # payload length and the optional 4-byte masking key (RFC 6455, section 5.2).
+ #
+ # Control frames are consumed here and never reach the reader:
+ # * close (0x8) closes `self`;
+ # * ping (0x9) is answered with a pong carrying the same payload;
+ # * pong (0xA) has its payload discarded.
+ #
+ # For data frames, `frame_type`, `frame_length`, `frame_cursor`,
+ # `has_mask`, `mask` and `mask_offset` are set up for the read methods.
+ # On a malformed or truncated frame, `last_error` is set and `self` is closed.
+ private fun read_frame_info do
+ var fst_byte = origin.read_byte
+ var snd_byte = origin.read_byte
+ if fst_byte < 0 or snd_byte < 0 then
+ frame_error("Error: bad frame")
+ return
+ end
+ # First byte in msg is formatted this way :
+ # |(fin - 1bit)|(RSV1 - 1bit)|(RSV2 - 1bit)|(RSV3 - 1bit)|(opcode - 4bits)
+ # fin = Flag indicating if current frame is the last one for the current message
+ # RSV1/2/3 = Extension flags, unsupported
+ # Opcode values :
+ # %x0 denotes a continuation frame
+ # %x1 denotes a text frame
+ # %x2 denotes a binary frame
+ # %x3-7 are reserved for further non-control frames
+ # %x8 denotes a connection close
+ # %x9 denotes a ping
+ # %xA denotes a pong
+ # %xB-F are reserved for further control frames
+ var opcode = fst_byte & 0b0000_1111
+ # An unknown (reserved) opcode must fail the connection
+ if (opcode >= 3 and opcode <= 7) or opcode >= 11 then
+ frame_error("Error: bad frame")
+ return
+ end
+ # Second byte is formatted this way :
+ # |(mask - 1bit)|(payload length - 7 bits)
+ # As specified, if the payload length is 126 or 127
+ # The next 16 or 64 bits contain an extended payload length
+ var mask_flag = snd_byte & 0b1000_0000
+ var len = snd_byte & 0b0111_1111
+ if len == 126 then
+ var tmp = origin.read_bytes(2)
+ if tmp.length != 2 then
+ frame_error("Error: received interrupted frame")
+ return
+ end
+ len = (tmp[0].to_i << 8) + tmp[1].to_i
+ else if len == 127 then
+ var tmp = origin.read_bytes(8)
+ if tmp.length != 8 then
+ frame_error("Error: received interrupted frame")
+ return
+ end
+ # The most significant bit of a 64-bit length must be 0
+ if (tmp[0].to_i & 0b1000_0000) != 0 then
+ frame_error("Error: bad frame")
+ return
+ end
+ len = 0
+ for i in [0 .. 8[ do
+ len += tmp[i].to_i << (8 * (7 - i))
+ end
+ end
+ # The masking key sits between the length and the payload, so it must
+ # be consumed before any payload, including a control frame's one
+ if mask_flag != 0 then
+ if origin.read_bytes_to_cstring(mask, 4) != 4 then
+ frame_error("Error: received interrupted frame")
+ return
+ end
+ has_mask = true
+ else
+ mask.memset(0, 4)
+ has_mask = false
+ end
+ # Unmasking restarts at the first key byte for every frame
+ mask_offset = 0
+ if opcode >= 8 then
+ read_control_frame(opcode, len)
+ return
+ end
+ frame_type = opcode
+ frame_length = len
+ frame_cursor = 0
+ end
+
+ # Record a framing error in `last_error` and close `self`
+ private fun frame_error(reason: String) do
+ last_error = new IOError(reason)
+ close
+ end
+
+ # Consume the `len`-byte payload of a control frame of type `opcode`
+ #
+ # Must be called right after the frame header, masking key included,
+ # has been read. Pings are answered with a pong echoing their payload.
+ private fun read_control_frame(opcode, len: Int) do
+ # Control frames may not carry more than 125 bytes of payload
+ if len > 125 then
+ frame_error("Error: bad frame")
+ return
+ end
+ if opcode == 8 then
+ close
+ return
+ end
+ if len == 0 then
+ if opcode == 9 then origin.write_bytes(pong_msg)
+ return
+ end
+ var payload = new CString(len)
+ if origin.read_bytes_to_cstring(payload, len) != len then
+ frame_error("Error: received interrupted frame")
+ return
+ end
+ # Unsolicited pongs (0xA) are simply dropped
+ if opcode != 9 then return
+ if has_mask then payload.xor(mask, len, 4, 0)
+ # A pong must carry the application data of the ping it answers
+ var header = new Bytes.with_capacity(2)
+ # FIN bit + opcode 0xA (pong)
+ header.add(0b1000_1010)
+ header.add(len)
+ origin.write_bytes(header)
+ origin.write_bytes_from_cstring(payload, len)
+ end
- # Unmasks a message sent by a client
- private fun unmask_message(key: String, message: String): String
- do
- var return_message = new FlatBuffer.with_capacity(message.length)
- var msg_iter = message.chars.iterator
-
- while msg_iter.is_ok do
- return_message.chars[msg_iter.index] = msg_iter.item.ascii.bin_xor(key.chars[msg_iter.index%4].ascii).ascii
- msg_iter.next
- end
- return return_message.to_s
- end
+ # Advance to the next frame when the current one is exhausted
+ #
+ # Frame headers are parsed by `read_frame_info`. Returns `false` if
+ # `self` is, or becomes, closed.
+ private fun wait_for_data_frame: Bool do
+ while not closed and frame_cursor >= frame_length do
+ read_frame_info
+ # Unmasking restarts at the first key byte for every frame
+ mask_offset = 0
+ end
+ return not closed
+ end
+
+ # Read one payload byte, unmasked, or return -1 once `self` is closed
+ redef fun raw_read_byte do
+ if not wait_for_data_frame then return -1
+ var b = origin.read_byte
+ if b < 0 then return b
+ if has_mask then
+ # Payloads sent by clients are XORed byte by byte with the 4-byte key
+ b = b ^ mask[mask_offset].to_i
+ mask_offset = (mask_offset + 1) % 4
+ end
+ frame_cursor += 1
+ return b
+ end
+
+ # Read up to `len` payload bytes of the current frame into `ns`, unmasked
+ #
+ # Never reads past the end of the current frame. Returns the number of
+ # bytes read, -1 if `self` is closed, or 0 if the read on `origin`
+ # failed, in which case `self` is closed.
+ redef fun raw_read_bytes(ns, len) do
+ if not wait_for_data_frame then return -1
+ var available = frame_length - frame_cursor
+ var to_rd = len.min(available)
+ var rd = origin.read_bytes_to_cstring(ns, to_rd)
+ if rd < 0 then
+ close
+ return 0
+ end
+ if has_mask then
+ ns.xor(mask, rd, 4, mask_offset)
+ # The key position carries over to the next read within this frame
+ mask_offset = (mask_offset + rd) % 4
+ end
+ frame_cursor += rd
+ return rd
+ end
# Checks if a connection to a client is available
- redef fun connected do return client.connected
+ # `false` once `self` is closed, even if `origin` still reports a connection
+ fun connected: Bool do return not closed and origin.connected
- redef fun write(msg)
- do
- client.write(frame_message(msg.to_s))
+ # Send `len` bytes from `ns` to the client in a single text frame
+ #
+ # NOTE(review): the bytes are framed as text (`frame_message` writes
+ # opcode 1); data that is not valid UTF-8 may be rejected by clients.
+ redef fun write_bytes_from_cstring(ns, len) do
+ origin.write_bytes(frame_message(ns.to_s_unsafe(len)))
end
- redef fun is_writable do return client.connected
-
- redef fun fill_buffer
- do
- _buffer.clear
- _buffer_pos = 0
- unpad_message
- end
+ # Send `msg` to the client in a single text frame
+ redef fun write(msg) do origin.write_bytes(frame_message(msg))
- redef fun end_reached do return client._buffer_pos >= client._buffer.length and client.end_reached
+ # Writable as long as `origin` reports a connection
+ # NOTE(review): unlike `connected`, this does not check `closed`
+ redef fun is_writable do return origin.connected
# Is there some data available to be read ?
- fun can_read(timeout: Int): Bool do return client.ready_to_read(timeout)
+ # Reports whether `origin` has bytes ready within `timeout`; those bytes
+ # may belong to a frame header or a control frame rather than payload.
+ fun can_read(timeout: Int): Bool do return not closed and origin.ready_to_read(timeout)
- redef fun poll_in do return client.poll_in
+ # Delegates to `origin`
+ redef fun poll_in do return origin.poll_in
end