kopia lustrzana https://github.com/espressif/esp-idf
404 wiersze
19 KiB
Python
404 wiersze
19 KiB
Python
from __future__ import print_function, unicode_literals
|
|
|
|
import os
|
|
import random
|
|
import re
|
|
import select
|
|
import socket
|
|
import ssl
|
|
import string
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from itertools import count
|
|
from threading import Event, Lock, Thread
|
|
|
|
import paho.mqtt.client as mqtt
|
|
import ttfw_idf
|
|
from common_test_methods import get_host_ip4_by_dest_ip
|
|
|
|
# Length (in characters) of the random sample string generated for each publish
# test; a published payload is this sample repeated `repeat` times, so the
# reported message size is repeat * DEFAULT_MSG_SIZE.
DEFAULT_MSG_SIZE = 16
|
|
|
|
|
|
def _path(f):
|
|
return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)
|
|
|
|
|
|
def set_server_cert_cn(ip):
    """Regenerate the test server certificate with its common name set to ``ip``.

    Creates a new CSR for ``server.key`` with subject ``/CN=<ip>`` and signs it
    with the local test CA (``ca.crt`` / ``ca.key``), writing ``srv.crt``.
    All files are resolved relative to this script's directory.

    :param ip: host IP address to embed as the certificate CN
    :raises RuntimeError: if any openssl invocation exits with a non-zero status
    """
    arg_list = [
        ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'), '-subj', '/CN={}'.format(ip), '-new'],
        ['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
         '-CAkey', _path('ca.key'), '-CAcreateserial', '-out', _path('srv.crt'), '-days', '360']]
    for args in arg_list:
        # subprocess.call returns the exit status, so this failure branch is
        # reachable; check_call would raise CalledProcessError instead and never
        # return a non-zero value.
        if subprocess.call(args) != 0:
            raise RuntimeError('openssl command {} failed'.format(args))
|
|
|
|
|
|
# Publisher class creating a python client to send/receive published data from esp-mqtt client
class MqttPublisher:
    """Python MQTT client exchanging published data with the esp-mqtt client on the DUT.

    The whole publish test runs inside ``__enter__``. The python client
    connects to the broker and subscribes to ``publish_cfg['subscribe_topic']``.
    It then sends the test parameters to the DUT console and publishes
    ``published`` messages. Finally it waits for the DUT's
    "Correct pattern received" report and for ``published`` copies of the
    expected payload to arrive back. ``__exit__`` disconnects the client.
    """

    def __init__(self, dut, transport, qos, repeat, published, queue, publish_cfg, log_details=False):
        # instance variables used as parameters of the publish test
        self.event_stop_client = Event()
        self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
        self.client = None
        self.dut = dut
        self.log_details = log_details
        self.repeat = repeat
        # NOTE: publish_cfg is the caller's dict and is updated in place
        self.publish_cfg = publish_cfg
        self.publish_cfg['qos'] = qos
        self.publish_cfg['queue'] = queue
        self.publish_cfg['transport'] = transport
        # serialises use of the paho client between the loop thread and this thread
        self.lock = Lock()
        # class-level variables used to pass options to and from the static
        # paho-mqtt callbacks (newly created Events start cleared)
        MqttPublisher.event_client_connected = Event()
        MqttPublisher.event_client_got_all = Event()
        MqttPublisher.published = published
        MqttPublisher.expected_data = self.sample_string * self.repeat

    def print_details(self, text):
        """Print ``text`` only when detailed logging was requested."""
        if self.log_details:
            print(text)

    def mqtt_client_task(self, client, lock):
        """Drive the paho network loop until ``event_stop_client`` is set."""
        while not self.event_stop_client.is_set():
            with lock:
                client.loop()
            time.sleep(0.001)  # yield to other threads

    # The callback for when the client receives a CONNACK response from the server (needs to be static)
    @staticmethod
    def on_connect(_client, _userdata, _flags, _rc):
        MqttPublisher.event_client_connected.set()

    # The callback for when a PUBLISH message is received from the server (needs to be static)
    @staticmethod
    def on_message(client, userdata, msg):
        # userdata holds the number of correct payloads received so far
        payload = msg.payload.decode()
        if payload == MqttPublisher.expected_data:
            userdata += 1
            client.user_data_set(userdata)
            if userdata == MqttPublisher.published:
                MqttPublisher.event_client_got_all.set()

    def __enter__(self):
        """Run the publish test (see the class docstring).

        :raises ValueError: if the python client cannot connect to the broker,
            or not all expected messages are received back from the DUT
        :return: this publisher instance
        """
        qos = self.publish_cfg['qos']
        queue = self.publish_cfg['queue']
        transport = self.publish_cfg['transport']
        broker_host = self.publish_cfg['broker_host_' + transport]
        broker_port = self.publish_cfg['broker_port_' + transport]

        # Start the test
        self.print_details("PUBLISH TEST: transport:{}, qos:{}, sequence:{}, enqueue:{}, sample msg:'{}'"
                           .format(transport, qos, MqttPublisher.published, queue, MqttPublisher.expected_data))

        try:
            if transport in ['ws', 'wss']:
                self.client = mqtt.Client(transport='websockets')
            else:
                self.client = mqtt.Client()
            self.client.on_connect = MqttPublisher.on_connect
            self.client.on_message = MqttPublisher.on_message
            self.client.user_data_set(0)

            if transport in ['ssl', 'wss']:
                self.client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
                self.client.tls_insecure_set(True)
            self.print_details('Connecting...')
            self.client.connect(broker_host, broker_port, 60)
        except Exception:
            self.print_details('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}'.format(broker_host))
            raise

        # Starting a py-client in a separate thread
        client_thread = Thread(target=self.mqtt_client_task, args=(self.client, self.lock))
        client_thread.start()
        # Everything after the loop thread starts is guarded by try/finally so
        # the thread is always stopped and joined, even when connecting to the
        # broker times out or the DUT interaction fails; otherwise the
        # non-daemon thread would keep the test process alive.
        try:
            self.print_details('Connecting py-client to broker {}:{}...'.format(broker_host, broker_port))
            if not MqttPublisher.event_client_connected.wait(timeout=30):
                raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_host))
            with self.lock:
                self.client.subscribe(self.publish_cfg['subscribe_topic'], qos)
            self.dut.write(' '.join(str(x) for x in (transport, self.sample_string, self.repeat, MqttPublisher.published, qos, queue)), eol='\n')
            # waiting till subscribed to defined topic
            self.dut.expect(re.compile(r'MQTT_EVENT_SUBSCRIBED'), timeout=30)
            for _ in range(MqttPublisher.published):
                with self.lock:
                    self.client.publish(self.publish_cfg['publish_topic'], self.sample_string * self.repeat, qos)
                self.print_details('Publishing...')
            self.print_details('Checking esp-client received msg published from py-client...')
            self.dut.expect(re.compile(r'Correct pattern received exactly x times'), timeout=60)
            if not MqttPublisher.event_client_got_all.wait(timeout=60):
                raise ValueError('Not all data received from ESP32')
            print(' - all data received from ESP32')
        finally:
            self.event_stop_client.set()
            client_thread.join()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Disconnect the python client and re-arm the stop event."""
        self.client.disconnect()
        self.event_stop_client.clear()
|
|
|
|
|
|
# Simple server for mqtt over TLS connection
class TlsServer:
    """Single-connection TLS server that answers exactly one MQTT CONNECT.

    Used as a context manager. ``__enter__`` binds and listens on ``port``, then
    starts a thread that accepts one TLS connection. That thread replies to the
    first MQTT CONNECT with a CONNACK, either accepted or "not authorized" when
    ``refuse_connection`` is set, and then stops. Any TLS handshake error and
    the negotiated ALPN protocol are recorded for the caller to inspect.
    """

    def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False):
        self.port = port
        self.socket = socket.socket()
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.settimeout(10.0)
        self.shutdown = Event()
        self.client_cert = client_cert
        self.refuse_connection = refuse_connection
        self.ssl_error = None
        self.use_alpn = use_alpn
        self.negotiated_protocol = None
        # Initialised up front so __exit__ works even when no client ever
        # connected (accept() timed out) or the server thread never started.
        self.conn = None
        self.server_thread = None

    def __enter__(self):
        try:
            self.socket.bind(('', self.port))
        except socket.error as e:
            print('Bind failed:{}'.format(e))
            # __exit__ is not called when __enter__ raises: release the socket here
            self.socket.close()
            raise

        self.socket.listen(1)
        self.server_thread = Thread(target=self.run_server)
        self.server_thread.start()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown.set()
        if self.server_thread is not None:
            self.server_thread.join()
        self.socket.close()
        if self.conn is not None:
            self.conn.close()

    def get_last_ssl_error(self):
        """Return the text of the TLS error seen by the server, or None."""
        return self.ssl_error

    def get_negotiated_protocol(self):
        """Return the ALPN protocol selected for the connection, or None."""
        return self.negotiated_protocol

    def run_server(self):
        """Accept one TLS connection and serve it (runs in the server thread).

        On a TLS handshake failure the error text is stored in ``ssl_error``
        and ``conn`` is reset to None.
        """
        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        if self.client_cert:
            context.verify_mode = ssl.CERT_REQUIRED
            context.load_verify_locations(cafile=_path('ca.crt'))
        context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key'))
        if self.use_alpn:
            context.set_alpn_protocols(['mymqtt', 'http/1.1'])
        self.socket = context.wrap_socket(self.socket, server_side=True)
        try:
            self.conn, address = self.socket.accept()  # accept new connection
            # NOTE(review): this re-applies the listening socket's timeout;
            # it may have been intended for self.conn instead.
            self.socket.settimeout(10.0)
            print(' - connection from: {}'.format(address))
            if self.use_alpn:
                self.negotiated_protocol = self.conn.selected_alpn_protocol()
                print(' - negotiated_protocol: {}'.format(self.negotiated_protocol))
            self.handle_conn()
        except ssl.SSLError as e:
            self.conn = None
            self.ssl_error = str(e)
            print(' - SSLError: {}'.format(str(e)))

    def handle_conn(self):
        """Poll the connection once a second until ``shutdown`` is set."""
        while not self.shutdown.is_set():
            readable, _, _ = select.select([self.conn], [], [], 1)
            try:
                if self.conn in readable:
                    self.process_mqtt_connect()

            except socket.error as err:
                print(' - error: {}'.format(err))
                raise

    def process_mqtt_connect(self):
        """Read the client's first packet and answer an MQTT CONNECT with a CONNACK.

        The expected prefix ``10 18 00 04 'MQTT'`` is a CONNECT packet with
        remaining length 0x18 and protocol name "MQTT". The reply is CONNACK
        ``20 02 00 00`` (accepted), or ``20 02 00 05`` (not authorized) when
        ``refuse_connection`` is set. ``shutdown`` is always set afterwards.

        :raises Exception: if the received data is not the expected CONNECT
        """
        try:
            data = bytearray(self.conn.recv(1024))
            message = ''.join(format(x, '02x') for x in data)
            if message[0:16] == '101800044d515454':
                if self.refuse_connection is False:
                    print(' - received mqtt connect, sending ACK')
                    self.conn.send(bytearray.fromhex('20020000'))
                else:
                    # injecting connection not authorized error
                    print(' - received mqtt connect, sending NAK')
                    self.conn.send(bytearray.fromhex('20020005'))
            else:
                raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message))
        finally:
            # stop the server after the connect message in happy flow, or if any exception occur
            self.shutdown.set()
|
|
|
|
|
|
def connection_tests(dut, cases, dut_ip):
    """Run the MQTT-over-TLS connection test cases against a local TlsServer.

    The server certificate is regenerated with its CN set to this host's IPv4
    address, as returned by ``get_host_ip4_by_dest_ip(dut_ip)``. Each case is
    started on the DUT with the console command ``conn <ip> <port> <case_id>``.
    The expected outcome is then checked in the DUT log and, where relevant,
    on the server side (TLS handshake error or negotiated ALPN protocol).

    :param dut: DUT object used to write commands and expect log output
    :param cases: mapping of case symbolic name (sdkconfig option) to case id
    :param dut_ip: IPv4 address of the DUT
    :raises RuntimeError: if the server observed an unexpected TLS error or ALPN result
    """
    ip = get_host_ip4_by_dest_ip(dut_ip)
    set_server_cert_cn(ip)
    server_port = 2222

    def teardown_connection_suite():
        dut.write('conn teardown 0 0')

    def start_connection_case(case, desc):
        print('Starting {}: {}'.format(case, desc))
        case_id = cases[case]
        dut.write('conn {} {} {}'.format(ip, server_port, case_id))
        dut.expect('Test case:{} started'.format(case_id))
        return case_id

    def check_ssl_error(server, expected):
        # Called after the server's `with` block has exited, i.e. after its
        # thread was joined, so ssl_error is final. It is still None if the
        # server never saw a handshake failure. Report that as a mismatch
        # instead of failing with TypeError on `expected in None`.
        last_error = server.get_last_ssl_error()
        if last_error is None or expected not in last_error:
            raise RuntimeError('Unexpected ssl error from the server {}'.format(last_error))

    for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
        # All these cases connect to the server with no server verification or with server only verification
        with TlsServer(server_port):
            test_nr = start_connection_case(case, 'default server - expect to connect normally')
            dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
        with TlsServer(server_port, refuse_connection=True):
            test_nr = start_connection_case(case, 'ssl shall connect, but mqtt sends connect refusal')
            dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
            dut.expect('MQTT ERROR: 0x5')  # expecting 0x5 ... connection not authorized error
        with TlsServer(server_port, client_cert=True) as s:
            test_nr = start_connection_case(case, 'server with client verification - handshake error since client presents no client certificate')
            dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
            dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED')  # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
        check_ssl_error(s, 'PEER_DID_NOT_RETURN_A_CERTIFICATE')

    for case in ['CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
        # These cases connect to server with both server and client verification (client key might be password protected)
        with TlsServer(server_port, client_cert=True):
            test_nr = start_connection_case(case, 'server with client verification - expect to connect normally')
            dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)

    case = 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
    with TlsServer(server_port) as s:
        test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error')
        dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
        dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED')  # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
    check_ssl_error(s, 'alert unknown ca')

    case = 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
    with TlsServer(server_port, client_cert=True) as s:
        test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error')
        dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
        dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED')  # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
    check_ssl_error(s, 'CERTIFICATE_VERIFY_FAILED')

    for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
        with TlsServer(server_port, use_alpn=True) as s:
            test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol')
            dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
            # the server records the ALPN result right after accept(), before it
            # can send the CONNACK that makes the DUT report CONNECTED
            if case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
                print(' - client with alpn off, no negotiated protocol: OK')
            elif case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
                print(' - client with alpn on, negotiated protocol resolved: OK')
            else:
                raise RuntimeError('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol()))

    teardown_connection_suite()
|
|
|
|
|
|
@ttfw_idf.idf_custom_test(env_tag='ethernet_router', group='test-apps')
def test_app_protocol_mqtt_publish_connect(env, extra_data):
    """
    steps:
      1. join AP
      2. connect to uri specified in the config
      3. send and receive data

    Environment variables:
      - MQTT_SKIP_CONNECT_TEST: when set, the TLS connection tests are skipped.
      - MQTT_PUBLISH_TEST: publish tests run only when this is set.
      - MQTT_PUBLISH_MSG_len_<i> / MQTT_PUBLISH_MSG_repeat_<i> (i = 0, 1, ...):
        optional overrides of the published message size multiplier and
        message count. Both variables of a pair must be set; parsing stops at
        the first missing pair.
    """
    dut1 = env.get_dut('mqtt_publish_connect_test', 'tools/test_apps/protocols/mqtt/publish_connect_test')
    # check and log bin size
    binary_file = os.path.join(dut1.app.binary_path, 'mqtt_publish_connect_test.bin')
    bin_size = os.path.getsize(binary_file)
    ttfw_idf.log_performance('mqtt_publish_connect_test_bin_size', '{}KB'.format(bin_size // 1024))

    # Look for test case symbolic names and publish configs
    cases = {}
    publish_cfg = {}
    try:
        # read the app's sdkconfig once and reuse it for every lookup below
        sdkconfig = dut1.app.get_sdkconfig()
        # Get connection test cases configuration: symbolic names for test cases
        for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
                     'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
                     'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
            cases[case] = sdkconfig[case]
    except Exception:
        print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
        raise

    dut1.start_app()
    esp_ip = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
    print('Got IP={}'.format(esp_ip))

    if not os.getenv('MQTT_SKIP_CONNECT_TEST'):
        connection_tests(dut1, cases, esp_ip)

    #
    # start publish tests only if enabled in the environment (for weekend tests only)
    if not os.getenv('MQTT_PUBLISH_TEST'):
        return

    def get_host_port_from_uri(uri):
        """Return (host, port) parsed from a ``scheme://host:port...`` URI, or (None, None)."""
        value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', uri)
        if value is None:
            return None, None
        return value.group(1), int(value.group(2))

    # Get publish test configuration
    try:
        # The python client publishes to the DUT's subscribe topic and
        # subscribes to the DUT's publish topic.
        publish_cfg['publish_topic'] = sdkconfig['CONFIG_EXAMPLE_SUBSCRIBE_TOPIC'].replace('"', '')
        publish_cfg['subscribe_topic'] = sdkconfig['CONFIG_EXAMPLE_PUBLISH_TOPIC'].replace('"', '')
        # e.g. 'ssl' -> CONFIG_EXAMPLE_BROKER_SSL_URI
        for transport in ['ssl', 'tcp', 'ws', 'wss']:
            uri = sdkconfig['CONFIG_EXAMPLE_BROKER_{}_URI'.format(transport.upper())]
            publish_cfg['broker_host_' + transport], publish_cfg['broker_port_' + transport] = get_host_port_from_uri(uri)
    except Exception:
        print('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
        raise

    def start_publish_case(transport, qos, repeat, published, queue):
        print('Starting Publish test: transport:{}, qos:{}, nr_of_msgs:{}, msg_size:{}, enqueue:{}'
              .format(transport, qos, published, repeat * DEFAULT_MSG_SIZE, queue))
        with MqttPublisher(dut1, transport, qos, repeat, published, queue, publish_cfg):
            pass

    # Initialize message sizes and repeat counts (if defined in the environment)
    messages = []
    for i in count(0):
        # Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
        env_dict = {var: 'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
        if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
            messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']})
            continue
        break
    if not messages:  # No message sizes present in the env - set defaults
        messages = [{'len': 0, 'repeat': 5},    # zero-sized messages
                    {'len': 2, 'repeat': 10},   # short messages
                    {'len': 200, 'repeat': 3},  # long messages
                    {'len': 20, 'repeat': 50}   # many medium sized
                    ]

    # Iterate over all publish message properties
    for qos in [0, 1, 2]:
        for transport in ['tcp', 'ssl', 'ws', 'wss']:
            # host is None when the broker URI in sdkconfig could not be parsed;
            # this does not depend on the enqueue flag, so check it once here
            if publish_cfg['broker_host_' + transport] is None:
                print('Skipping transport: {}...'.format(transport))
                continue
            for q in [0, 1]:
                for msg in messages:
                    start_publish_case(transport, qos, msg['len'], msg['repeat'], q)
|
|
|
|
|
|
if __name__ == '__main__':
    # Passing the single argument 'qemu' selects the QEMU DUT instead of the ESP32 DUT.
    use_qemu = sys.argv[1:] == ['qemu']
    dut_class = ttfw_idf.ESP32QEMUDUT if use_qemu else ttfw_idf.ESP32DUT
    test_app_protocol_mqtt_publish_connect(dut=dut_class)
|