175 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
			
		
		
	
	
			175 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
// SPDX-License-Identifier: GPL-2.0
 | 
						|
/* Copyright (c) 2020, Tessares SA. */
 | 
						|
/* Copyright (c) 2022, SUSE. */
 | 
						|
 | 
						|
#include <test_progs.h>
 | 
						|
#include "cgroup_helpers.h"
 | 
						|
#include "network_helpers.h"
 | 
						|
#include "mptcp_sock.skel.h"
 | 
						|
 | 
						|
/*
 * Fallback size for congestion-control name buffers, used only when no
 * previously included header already provides TCP_CA_NAME_MAX.
 */
#if !defined(TCP_CA_NAME_MAX)
#define TCP_CA_NAME_MAX	16
#endif
 | 
						|
 | 
						|
struct mptcp_storage {
 | 
						|
	__u32 invoked;
 | 
						|
	__u32 is_mptcp;
 | 
						|
	struct sock *sk;
 | 
						|
	__u32 token;
 | 
						|
	struct sock *first;
 | 
						|
	char ca_name[TCP_CA_NAME_MAX];
 | 
						|
};
 | 
						|
 | 
						|
static int verify_tsk(int map_fd, int client_fd)
 | 
						|
{
 | 
						|
	int err, cfd = client_fd;
 | 
						|
	struct mptcp_storage val;
 | 
						|
 | 
						|
	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
 | 
						|
	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
 | 
						|
		return err;
 | 
						|
 | 
						|
	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
 | 
						|
		err++;
 | 
						|
 | 
						|
	if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
 | 
						|
		err++;
 | 
						|
 | 
						|
	return err;
 | 
						|
}
 | 
						|
 | 
						|
static void get_msk_ca_name(char ca_name[])
 | 
						|
{
 | 
						|
	size_t len;
 | 
						|
	int fd;
 | 
						|
 | 
						|
	fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
 | 
						|
	if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
 | 
						|
		return;
 | 
						|
 | 
						|
	len = read(fd, ca_name, TCP_CA_NAME_MAX);
 | 
						|
	if (!ASSERT_GT(len, 0, "failed to read ca_name"))
 | 
						|
		goto err;
 | 
						|
 | 
						|
	if (len > 0 && ca_name[len - 1] == '\n')
 | 
						|
		ca_name[len - 1] = '\0';
 | 
						|
 | 
						|
err:
 | 
						|
	close(fd);
 | 
						|
}
 | 
						|
 | 
						|
/*
 * Check the storage entry recorded for an MPTCP client socket.
 *
 * @token is the token the caller expects (run_test() passes the value the
 * BPF object exposes in its .bss); it must be non-zero and match the stored
 * token. The entry must also show exactly one invocation, is_mptcp set,
 * first == sk, and ca_name equal to the system default congestion control.
 *
 * Returns -1 for a zero @token, the lookup error if no entry exists for
 * @client_fd, otherwise the number of checks that failed (0 on success).
 */
static int verify_msk(int map_fd, int client_fd, __u32 token)
{
	/*
	 * Start empty so that, if the name cannot be read, the comparison
	 * below is well-defined instead of reading uninitialized stack memory.
	 */
	char ca_name[TCP_CA_NAME_MAX] = "";
	int err, cfd = client_fd;
	struct mptcp_storage val;

	if (!ASSERT_GT(token, 0, "invalid token"))
		return -1;

	get_msk_ca_name(ca_name);

	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
		return err;

	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
		err++;

	if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
		err++;

	if (!ASSERT_EQ(val.token, token, "unexpected token"))
		err++;

	if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
		err++;

	if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
		err++;

	return err;
}
 | 
						|
 | 
						|
static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
 | 
						|
{
 | 
						|
	int client_fd, prog_fd, map_fd, err;
 | 
						|
	struct mptcp_sock *sock_skel;
 | 
						|
 | 
						|
	sock_skel = mptcp_sock__open_and_load();
 | 
						|
	if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
 | 
						|
		return -EIO;
 | 
						|
 | 
						|
	err = mptcp_sock__attach(sock_skel);
 | 
						|
	if (!ASSERT_OK(err, "skel_attach"))
 | 
						|
		goto out;
 | 
						|
 | 
						|
	prog_fd = bpf_program__fd(sock_skel->progs._sockops);
 | 
						|
	if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) {
 | 
						|
		err = -EIO;
 | 
						|
		goto out;
 | 
						|
	}
 | 
						|
 | 
						|
	map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
 | 
						|
	if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) {
 | 
						|
		err = -EIO;
 | 
						|
		goto out;
 | 
						|
	}
 | 
						|
 | 
						|
	err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
 | 
						|
	if (!ASSERT_OK(err, "bpf_prog_attach"))
 | 
						|
		goto out;
 | 
						|
 | 
						|
	client_fd = connect_to_fd(server_fd, 0);
 | 
						|
	if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
 | 
						|
		err = -EIO;
 | 
						|
		goto out;
 | 
						|
	}
 | 
						|
 | 
						|
	err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
 | 
						|
			  verify_tsk(map_fd, client_fd);
 | 
						|
 | 
						|
	close(client_fd);
 | 
						|
 | 
						|
out:
 | 
						|
	mptcp_sock__destroy(sock_skel);
 | 
						|
	return err;
 | 
						|
}
 | 
						|
 | 
						|
static void test_base(void)
 | 
						|
{
 | 
						|
	int server_fd, cgroup_fd;
 | 
						|
 | 
						|
	cgroup_fd = test__join_cgroup("/mptcp");
 | 
						|
	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
 | 
						|
		return;
 | 
						|
 | 
						|
	/* without MPTCP */
 | 
						|
	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
 | 
						|
	if (!ASSERT_GE(server_fd, 0, "start_server"))
 | 
						|
		goto with_mptcp;
 | 
						|
 | 
						|
	ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
 | 
						|
 | 
						|
	close(server_fd);
 | 
						|
 | 
						|
with_mptcp:
 | 
						|
	/* with MPTCP */
 | 
						|
	server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
 | 
						|
	if (!ASSERT_GE(server_fd, 0, "start_mptcp_server"))
 | 
						|
		goto close_cgroup_fd;
 | 
						|
 | 
						|
	ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
 | 
						|
 | 
						|
	close(server_fd);
 | 
						|
 | 
						|
close_cgroup_fd:
 | 
						|
	close(cgroup_fd);
 | 
						|
}
 | 
						|
 | 
						|
/* Top-level test entry point; runs the "base" subtest when it is selected. */
void test_mptcp(void)
{
	if (!test__start_subtest("base"))
		return;

	test_base();
}
 |