31 #ifndef WEBSOCKET_FRAME_HPP
32 #define WEBSOCKET_FRAME_HPP
34 #include "wscommon.hpp"
36 #include "btexception.hpp"
41 #include <arpa/inet.h>
49 static int64_t htonll(int64_t v) {
50 static int HOST_IS_LE = 0x1234;
51 if (HOST_IS_LE == 0x1234)
52 HOST_IS_LE = (htons(1) != 1);
54 union { uint32_t hv[2]; int64_t v; } u;
55 u.hv[0] = htonl(v >> 32);
56 u.hv[1] = htonl(v & 0x0FFFFFFFFULL);
61 #define ntohll(x) htonll(x)
70 class wserror :
public exception
76 PROTOCOL_VIOLATION = 2,
77 PAYLOAD_VIOLATION = 3,
78 INTERNAL_ENDPOINT_ERROR = 4,
88 explicit wserror(
const std::string& __arg,
int code = wserror::FATAL_ERROR)
96 virtual const char*
what()
const throw()
97 {
return msg.c_str(); }
102 virtual int code()
const {
return ecode; }
127 return ::rand_r(&seed);
156 template <
class rng_policy>
// Bit masks applied to byte 0 of the header (m_header[0]): the opcode occupies
// the low nibble, followed by the three reserved bits, with FIN in the top bit.
163 static const uint8_t BPB0_OPCODE = 0x0F;
164 static const uint8_t BPB0_RSV3 = 0x10;
165 static const uint8_t BPB0_RSV2 = 0x20;
166 static const uint8_t BPB0_RSV1 = 0x40;
167 static const uint8_t BPB0_FIN = 0x80;
// Bit masks applied to byte 1 of the header (m_header[1]): the 7-bit basic
// payload length and the MASK flag.
168 static const uint8_t BPB1_PAYLOAD = 0x7F;
169 static const uint8_t BPB1_MASK = 0x80;
// Values of the 7-bit basic length field meaning "the real length follows the
// basic header as a 16-bit / 64-bit big-endian integer".
171 static const uint8_t BASIC_PAYLOAD_16BIT_CODE = 0x7E;
172 static const uint8_t BASIC_PAYLOAD_64BIT_CODE = 0x7F;
// Size of the fixed basic header, and the capacity of the m_header buffer
// (presumably 2 basic + 8 extended-length + 4 masking-key bytes).
174 static const unsigned int BASIC_HEADER_LENGTH = 2;
175 static const unsigned int MAX_HEADER_LENGTH = 14;
// NOTE(review): not referenced anywhere in the visible code; confirm it is still used.
176 static const uint8_t extended_header_length = 12;
// Implementation-defined payload size cap; payload sizes above this are
// rejected both when setting a payload and when parsing an extended header.
177 static const uint64_t max_payload_size = 100000000;
185 : m_state(STATE_BASIC_HEADER)
186 , m_bytes_needed(BASIC_HEADER_LENGTH)
188 , m_payload(std::vector<unsigned char>())
199 return (m_state == STATE_READY);
206 m_state = STATE_BASIC_HEADER;
207 m_bytes_needed = BASIC_HEADER_LENGTH;
210 std::fill(m_header,m_header+MAX_HEADER_LENGTH,0);
222 void consume(std::istream &s) {
225 case STATE_BASIC_HEADER:
226 s.read(&m_header[BASIC_HEADER_LENGTH-m_bytes_needed],m_bytes_needed);
228 m_bytes_needed -= s.gcount();
230 if (m_bytes_needed == 0) {
231 process_basic_header();
233 validate_basic_header();
235 if (m_bytes_needed > 0) {
236 m_state = STATE_EXTENDED_HEADER;
238 process_extended_header();
240 if (m_bytes_needed == 0) {
241 m_state = STATE_READY;
245 m_state = STATE_PAYLOAD;
250 case STATE_EXTENDED_HEADER:
251 s.read(&m_header[get_header_len()-m_bytes_needed],m_bytes_needed);
253 m_bytes_needed -= s.gcount();
255 if (m_bytes_needed == 0) {
256 process_extended_header();
257 if (m_bytes_needed == 0) {
258 m_state = STATE_READY;
261 m_state = STATE_PAYLOAD;
266 s.read(reinterpret_cast<char *>(&m_payload[m_payload.size()-m_bytes_needed]),
269 m_bytes_needed -= s.gcount();
271 if (m_bytes_needed == 0) {
272 m_state = STATE_READY;
280 s.read(reinterpret_cast<char *>(&m_header[0]),1);
281 if (
int(static_cast<unsigned char>(m_header[0])) == 0x88) {
284 m_state = STATE_BASIC_HEADER;
287 }
while (s.gcount() > 0);
293 }
catch (
const std::exception & e) {
297 if (m_degraded ==
true) {
298 throw tracing::wserror(
"An error occurred while trying to gracefully recover from a less serious frame error.");
301 m_state = STATE_RECOVERY;
314 return std::string(m_header, get_header_len());
322 return std::string(m_payload.begin(), m_payload.end());
347 m_header[0] |= BPB0_FIN;
349 m_header[0] &= (0xFF ^ BPB0_FIN);
358 return frame::opcode::value(m_header[0] & BPB0_OPCODE);
366 if (opcode::reserved(op)) {
367 throw tracing::wserror(
"reserved opcode",tracing::wserror::PROTOCOL_VIOLATION);
370 if (opcode::invalid(op)) {
371 throw tracing::wserror(
"invalid opcode",tracing::wserror::PROTOCOL_VIOLATION);
374 if (
is_control() && get_basic_size() > limits::PAYLOAD_SIZE_BASIC) {
375 throw tracing::wserror(
"control frames can't have large payloads",tracing::wserror::PROTOCOL_VIOLATION);
378 m_header[0] &= (0xFF ^ BPB0_OPCODE);
388 m_header[1] |= BPB1_MASK;
389 generate_masking_key();
391 m_header[1] &= (0xFF ^ BPB1_MASK);
401 set_payload_helper(source.size());
403 std::copy(source.begin(),source.end(),m_payload.begin());
410 void set_payload(
const std::vector<unsigned char>& source) {
411 set_payload_helper(source.size());
413 std::copy(source.begin(),source.end(),m_payload.begin());
421 if (m_payload.size() == 0) {
422 return close::status::NO_STATUS;
424 return close::status::value(get_raw_close_code());
433 if (m_payload.size() > 2) {
436 return std::string();
446 uint64_t get_bytes_needed()
const {
447 return m_bytes_needed;
454 char* get_extended_header() {
455 return m_header+BASIC_HEADER_LENGTH;
457 unsigned int get_header_len()
const {
458 unsigned int temp = 2;
464 if (get_basic_size() == 126) {
466 }
else if (get_basic_size() == 127) {
473 char* get_masking_key() {
474 return &m_header[get_header_len()-4];
478 bool get_fin()
const {
479 return ((m_header[0] & BPB0_FIN) == BPB0_FIN);
481 bool get_rsv1()
const {
482 return ((m_header[0] & BPB0_RSV1) == BPB0_RSV1);
484 void set_rsv1(
bool b) {
486 m_header[0] |= BPB0_RSV1;
488 m_header[0] &= (0xFF ^ BPB0_RSV1);
492 bool get_rsv2()
const {
493 return ((m_header[0] & BPB0_RSV2) == BPB0_RSV2);
495 void set_rsv2(
bool b) {
497 m_header[0] |= BPB0_RSV2;
499 m_header[0] &= (0xFF ^ BPB0_RSV2);
503 bool get_rsv3()
const {
504 return ((m_header[0] & BPB0_RSV3) == BPB0_RSV3);
506 void set_rsv3(
bool b) {
508 m_header[0] |= BPB0_RSV3;
510 m_header[0] &= (0xFF ^ BPB0_RSV3);
514 bool get_masked()
const {
515 return ((m_header[1] & BPB1_MASK) == BPB1_MASK);
517 uint8_t get_basic_size()
const {
518 return m_header[1] & BPB1_PAYLOAD;
520 size_t get_payload_size()
const {
521 if (m_state != STATE_READY && m_state != STATE_PAYLOAD) {
523 throw "attempted to get payload size before reading full header";
526 return m_payload.size();
529 close::status::value get_close_status()
const {
530 if (get_payload_size() == 0) {
531 return close::status::NO_STATUS;
532 }
else if (get_payload_size() >= 2) {
533 char val[2] = { m_payload[0], m_payload[1] };
536 std::copy(val,val+
sizeof(code),&code);
539 return close::status::value(code);
541 return close::status::PROTOCOL_ERROR;
544 std::string get_close_msg()
const {
545 if (get_payload_size() > 2) {
546 uint32_t state = utf8_validator::UTF8_ACCEPT;
548 validate_utf8(&state,&codep,2);
549 if (state != utf8_validator::UTF8_ACCEPT) {
550 throw tracing::wserror(
"Invalid UTF-8 Data",tracing::wserror::PAYLOAD_VIOLATION);
552 return std::string(m_payload.begin()+2,m_payload.end());
554 return std::string();
558 void set_payload_helper(uint64_t s) {
559 if (s > max_payload_size) {
560 throw tracing::wserror(
"requested payload is over implementation defined limit",tracing::wserror::MESSAGE_TOO_BIG);
564 if (
is_control() && s > limits::PAYLOAD_SIZE_BASIC) {
565 throw tracing::wserror(
"control frames can't have large payloads",tracing::wserror::PROTOCOL_VIOLATION);
568 bool masked = get_masked();
570 if (s <= limits::PAYLOAD_SIZE_BASIC) {
572 }
else if (s <= limits::PAYLOAD_SIZE_EXTENDED) {
573 m_header[1] = BASIC_PAYLOAD_16BIT_CODE;
578 *
reinterpret_cast<uint16_t*
>(&m_header[BASIC_HEADER_LENGTH]) = htons(s);
579 }
else if (s <= limits::PAYLOAD_SIZE_JUMBO) {
580 m_header[1] = BASIC_PAYLOAD_64BIT_CODE;
581 *
reinterpret_cast<uint64_t*
>(&m_header[BASIC_HEADER_LENGTH]) = htonll(s);
583 throw tracing::wserror(
"payload size limit is 63 bits",tracing::wserror::PROTOCOL_VIOLATION);
587 m_header[1] |= BPB1_MASK;
593 void set_status(close::status::value status,
const std::string message =
"") {
595 if (close::status::invalid(status)) {
596 std::stringstream err;
597 err <<
"Status code " << status <<
" is invalid";
601 if (close::status::reserved(status)) {
602 std::stringstream err;
603 err <<
"Status code " << status <<
" is reserved";
607 m_payload.resize(2+message.size());
611 *
reinterpret_cast<uint16_t*
>(&val[0]) = htons(status);
613 bool masked = get_masked();
615 m_header[1] = message.size()+2;
618 m_header[1] |= BPB1_MASK;
621 m_payload[0] = val[0];
622 m_payload[1] = val[1];
624 std::copy(message.begin(),message.end(),m_payload.begin()+2);
627 std::string print_frame()
const {
630 unsigned int len = get_header_len();
634 for (
unsigned int i = 0; i < len; i++) {
635 f << std::hex << (
unsigned short)m_header[i] <<
" ";
638 if (m_payload.size() > 50) {
639 f <<
"[payload of " << m_payload.size() <<
" bytes]";
641 std::vector<unsigned char>::const_iterator it;
642 for (it = m_payload.begin(); it != m_payload.end(); it++) {
650 void process_basic_header() {
651 m_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH;
653 void process_extended_header() {
654 uint8_t s = get_basic_size();
655 uint64_t payload_size;
656 int mask_index = BASIC_HEADER_LENGTH;
658 if (s <= limits::PAYLOAD_SIZE_BASIC) {
660 }
else if (s == BASIC_PAYLOAD_16BIT_CODE) {
663 payload_size = ntohs(*(
664 reinterpret_cast<uint16_t*>(&m_header[BASIC_HEADER_LENGTH])
667 if (payload_size < s) {
668 std::stringstream err;
669 err <<
"payload length not minimally encoded. Using 16 bit form for payload size: " << payload_size;
670 m_bytes_needed = payload_size;
675 }
else if (s == BASIC_PAYLOAD_64BIT_CODE) {
678 payload_size = ntohll(*(
679 reinterpret_cast<uint64_t*>(&m_header[BASIC_HEADER_LENGTH])
682 if (payload_size <= limits::PAYLOAD_SIZE_EXTENDED) {
683 m_bytes_needed = payload_size;
685 tracing::wserror::PROTOCOL_VIOLATION);
691 throw tracing::wserror(
"invalid get_basic_size in process_extended_header");
694 if (get_masked() == 0) {
706 if (payload_size > max_payload_size) {
711 throw "Got frame with payload greater than maximum frame buffer size.";
713 m_payload.resize(payload_size);
714 m_bytes_needed = payload_size;
717 void process_payload() {
719 char *masking_key = get_masking_key();
721 for (uint64_t i = 0; i < m_payload.size(); i++) {
722 m_payload[i] = (m_payload[i] ^ masking_key[i%4]);
728 void process_payload2() {
753 void validate_utf8(uint32_t* state,uint32_t* codep,
size_t offset = 0)
const {
754 for (
size_t i = offset; i < m_payload.size(); i++) {
755 using utf8_validator::decode;
757 if (decode(state,codep,m_payload[i]) == utf8_validator::UTF8_REJECT) {
758 throw tracing::wserror(
"Invalid UTF-8 Data",tracing::wserror::PAYLOAD_VIOLATION);
762 void validate_basic_header()
const {
764 if (
is_control() && get_basic_size() > limits::PAYLOAD_SIZE_BASIC) {
765 throw tracing::wserror(
"Control Frame is too large",tracing::wserror::PROTOCOL_VIOLATION);
769 if (get_rsv1() || get_rsv2() || get_rsv3()) {
770 throw tracing::wserror(
"Reserved bit used",tracing::wserror::PROTOCOL_VIOLATION);
775 throw tracing::wserror(
"Reserved opcode used",tracing::wserror::PROTOCOL_VIOLATION);
780 throw tracing::wserror(
"Fragmented control message",tracing::wserror::PROTOCOL_VIOLATION);
784 void generate_masking_key() {
785 *(
reinterpret_cast<int32_t *
>(&m_header[get_header_len()-4])) = m_rng.gen();
787 void clear_masking_key() {
795 uint16_t get_raw_close_code()
const {
796 if (m_payload.size() <= 1) {
797 throw tracing::wserror(
"get_raw_close_code called with invalid size",tracing::wserror::FATAL_ERROR);
800 union {uint16_t i;
char c[2];} val;
802 val.c[0] = m_payload[0];
803 val.c[1] = m_payload[1];
808 static const uint8_t STATE_BASIC_HEADER = 1;
809 static const uint8_t STATE_EXTENDED_HEADER = 2;
810 static const uint8_t STATE_PAYLOAD = 3;
811 static const uint8_t STATE_READY = 4;
812 static const uint8_t STATE_RECOVERY = 5;
815 uint64_t m_bytes_needed;
818 char m_header[MAX_HEADER_LENGTH];
819 std::vector<unsigned char> m_payload;
827 #endif // WEBSOCKET_FRAME_HPP