#include "test/jemalloc_test.h"
#include "test/arena_util.h"
#include "test/san.h"

#include "jemalloc/internal/cache_bin.h"
#include "jemalloc/internal/san.h"
#include "jemalloc/internal/safety_check.h"

/*
 * Allocator option string for this test binary (jemalloc's malloc_conf
 * global). The value comes from TEST_SAN_UAF_ALIGN_ENABLE in test/san.h;
 * presumably it enables UAF-alignment sanitization -- confirm against that
 * header.
 */
const char *malloc_conf = TEST_SAN_UAF_ALIGN_ENABLE;

/*
 * Byte alignment 1 << opt.lg_san_uaf_align, computed and stored by
 * uaf_detection_enabled(). test_use_after_free() expects multiples of this
 * value to satisfy cache_bin_nonfast_aligned().
 */
static size_t san_uaf_align;

/* Set by fake_abort(); lets a test observe that a safety check fired. */
static bool fake_abort_called;

/*
 * Abort hook installed via safety_check_set_abort() during write-after-free
 * tests. It only records that it was invoked; the message is ignored.
 */
void
fake_abort(const char *message) {
	fake_abort_called = true;
	(void)message;
}

/*
 * Arm a write-after-free scenario: clear the "hook fired" flag and route
 * safety-check aborts to fake_abort(), so that a detected violation only
 * sets fake_abort_called.
 */
static void
test_write_after_free_pre(void) {
	fake_abort_called = false;
	safety_check_set_abort(&fake_abort);
}

/*
 * Finish a write-after-free scenario. Flushing the thread's tcache (including
 * stashed pointers) is what gives the check a chance to run. Afterwards,
 * verify that fake_abort() was invoked and then uninstall the hook.
 */
static void
test_write_after_free_post(void) {
	int flush_ret = mallctl("thread.tcache.flush", NULL, NULL, NULL, 0);
	assert_d_eq(flush_ret, 0, "Unexpected tcache flush failure");

	expect_true(fake_abort_called, "Use-after-free check didn't fire.");

	/* NOTE(review): NULL presumably restores the default abort handler. */
	safety_check_set_abort(NULL);
}

static bool
uaf_detection_enabled(void) {
	if (!config_uaf_detection || !san_uaf_detection_enabled()) {
		return false;
	}

	ssize_t lg_san_uaf_align;
	size_t sz = sizeof(lg_san_uaf_align);
	assert_d_eq(mallctl("opt.lg_san_uaf_align", &lg_san_uaf_align, &sz,
	    NULL, 0), 0, "Unexpected mallctl failure");
	if (lg_san_uaf_align < 0) {
		return false;
	}
	assert_zd_ge(lg_san_uaf_align, LG_PAGE, "san_uaf_align out of range");
	san_uaf_align = (size_t)1 << lg_san_uaf_align;

	bool tcache_enabled;
	sz = sizeof(tcache_enabled);
	assert_d_eq(mallctl("thread.tcache.enabled", &tcache_enabled, &sz, NULL,
	    0), 0, "Unexpected mallctl failure");
	if (!tcache_enabled) {
		return false;
	}

	return true;
}

/*
 * Return the tcache-stashed byte count from the merged stats of all arenas.
 * Returns 0 when stats are compiled out (config_stats is false).
 *
 * Writing "epoch" first refreshes the stats snapshot that the subsequent
 * read reports.
 *
 * NOTE(review): arena_ind is not used. The value always comes from
 * stats.arenas.<MALLCTL_ARENAS_ALL>. Presumably stashed bytes are not
 * attributed to arena_ind itself, since stashing happens in the thread's
 * tcache. Confirm this before switching to a per-arena read.
 */
static size_t
read_tcache_stashed_bytes(unsigned arena_ind) {
	(void)arena_ind;
	if (!config_stats) {
		return 0;
	}

	/*
	 * Only the act of writing "epoch" matters. However, mallctl() reads
	 * sizeof(epoch) bytes from the new-value pointer, so the variable must
	 * be initialized rather than passed in indeterminate.
	 */
	uint64_t epoch = 1;
	assert_d_eq(mallctl("epoch", NULL, NULL, (void *)&epoch, sizeof(epoch)),
	    0, "Unexpected mallctl() failure");

	size_t tcache_stashed_bytes;
	size_t sz = sizeof(tcache_stashed_bytes);
	assert_d_eq(mallctl(
	    "stats.arenas." STRINGIFY(MALLCTL_ARENAS_ALL)
	    ".tcache_stashed_bytes", &tcache_stashed_bytes, &sz, NULL, 0), 0,
	    "Unexpected mallctl failure");

	return tcache_stashed_bytes;
}

/*
 * Run one use-after-free scenario for objects of alloc_size bytes.
 *
 * Steps:
 *   1. Sanity-check cache_bin_nonfast_aligned() against san_uaf_align.
 *   2. From a fresh arena, with MALLOCX_TCACHE_NONE, allocate objects until
 *      one is nonfast-aligned, stamping each with a magic byte.
 *   3. Free the objects with free() in reverse allocation order. Any object
 *      whose first byte changed after free() must hold uaf_detect_junk, and
 *      at least one object must have been junked.
 *   4. When write_after_free is true, overwrite each junked object after
 *      free(). The safety-check hook is then expected to fire by the end.
 *
 * san_uaf_align must already have been set by uaf_detection_enabled().
 */
static void
test_use_after_free(size_t alloc_size, bool write_after_free) {
	void *ptr = (void *)(uintptr_t)san_uaf_align;
	assert_true(cache_bin_nonfast_aligned(ptr), "Wrong alignment");
	ptr = (void *)((uintptr_t)123 * (uintptr_t)san_uaf_align);
	assert_true(cache_bin_nonfast_aligned(ptr), "Wrong alignment");
	ptr = (void *)((uintptr_t)san_uaf_align + 1);
	assert_false(cache_bin_nonfast_aligned(ptr), "Wrong alignment");

	/*
	 * Disable purging (-1) so that all dirty pages remain committed, to
	 * make use-after-free tolerable.
	 */
	unsigned arena_ind = do_arena_create(-1, -1);
	int flags = MALLOCX_ARENA(arena_ind) | MALLOCX_TCACHE_NONE;

	size_t n_max = san_uaf_align * 2;
	void **items = mallocx(n_max * sizeof(void *), flags);
	assert_ptr_not_null(items, "Unexpected mallocx failure");

	bool found = false;
	size_t iter = 0;
	char magic = 's';
	assert_d_eq(mallctl("thread.tcache.flush", NULL, NULL, NULL, 0),
	    0, "Unexpected tcache flush failure");
	while (!found) {
		/*
		 * Check capacity before storing, so that failing to find an
		 * aligned pointer cannot write past the end of items.
		 */
		assert_zu_lt(iter, n_max, "No aligned ptr found");
		ptr = mallocx(alloc_size, flags);
		assert_ptr_not_null(ptr, "Unexpected mallocx failure");

		found = cache_bin_nonfast_aligned(ptr);
		*(char *)ptr = magic;
		items[iter++] = ptr;
	}

	if (write_after_free) {
		test_write_after_free_pre();
	}
	bool junked = false;
	while (iter-- != 0) {
		/*
		 * volatile presumably keeps the compiler from folding away the
		 * deliberate reads of freed memory below.
		 */
		char *volatile mem = items[iter];
		assert_c_eq(*mem, magic, "Unexpected memory content");
		size_t stashed_before = read_tcache_stashed_bytes(arena_ind);
		free(mem);
		if (*mem != magic) {
			junked = true;
			assert_c_eq(*mem, (char)uaf_detect_junk,
			    "Unexpected junk-filling bytes");
			if (write_after_free) {
				*(char *)mem = magic + 1;
			}

			size_t stashed_after = read_tcache_stashed_bytes(
			    arena_ind);
			/*
			 * An edge case is the deallocation above triggering the
			 * tcache GC event, in which case the stashed pointers
			 * may get flushed immediately, before returning from
			 * free().  Treat these cases as checked already.
			 */
			if (stashed_after <= stashed_before) {
				fake_abort_called = true;
			}
		}
		/* Flush tcache (including stashed). */
		assert_d_eq(mallctl("thread.tcache.flush", NULL, NULL, NULL, 0),
		    0, "Unexpected tcache flush failure");
	}
	expect_true(junked, "Aligned ptr not junked");
	if (write_after_free) {
		test_write_after_free_post();
	}

	dallocx(items, flags);
	do_arena_destroy(arena_ind);
}

TEST_BEGIN(test_read_after_free) {
	test_skip_if(!uaf_detection_enabled());

	/*
	 * Allocation sizes to probe, in this order. Several sit one byte above
	 * a power of two (33, 65, 129).
	 */
	static const size_t sizes[] = {
		sizeof(void *), sizeof(void *) + 1, 16, 20, 32, 33, 48, 64, 65,
		129, 255, 256
	};
	const size_t n_sizes = sizeof(sizes) / sizeof(sizes[0]);

	/* Read-only scenarios: freed memory is inspected, never written. */
	for (size_t i = 0; i < n_sizes; i++) {
		test_use_after_free(sizes[i], /* write_after_free */ false);
	}
}
TEST_END

TEST_BEGIN(test_write_after_free) {
	test_skip_if(!uaf_detection_enabled());

	/*
	 * Same size sweep as the read-after-free test. Here each junked object
	 * is also written to after free(), and the safety-check hook must fire.
	 */
	static const size_t probe_sizes[] = {
		sizeof(void *), sizeof(void *) + 1, 16, 20, 32, 33, 48, 64, 65,
		129, 255, 256
	};
	const size_t *const end =
	    probe_sizes + sizeof(probe_sizes) / sizeof(*probe_sizes);

	for (const size_t *sz = probe_sizes; sz != end; sz++) {
		test_use_after_free(*sz, /* write_after_free */ true);
	}
}
TEST_END

/*
 * Check that the list emulated over `allocated` is intact. The first
 * pointer-sized word of every element must equal some element of the same
 * array. Returns false as soon as one element fails; trivially true when
 * n_alloc is 0.
 *
 * Indices are size_t to match n_alloc. With unsigned indices, a count above
 * UINT_MAX would make the loops never terminate.
 */
static bool
check_allocated_intact(void **allocated, size_t n_alloc) {
	for (size_t i = 0; i < n_alloc; i++) {
		void *next = *(void **)allocated[i];

		size_t j = 0;
		while (j < n_alloc && allocated[j] != next) {
			j++;
		}
		if (j == n_alloc) {
			/* The link points outside the array. */
			return false;
		}
	}

	return true;
}

TEST_BEGIN(test_use_after_free_integration) {
	test_skip_if(!uaf_detection_enabled());

	unsigned arena_ind = do_arena_create(-1, -1);
	int flags = MALLOCX_ARENA(arena_ind);

	size_t n_alloc = san_uaf_align * 2;
	void **allocated = mallocx(n_alloc * sizeof(void *), flags);
	assert_ptr_not_null(allocated, "Unexpected mallocx failure");

	for (unsigned i = 0; i < n_alloc; i++) {
		allocated[i] = mallocx(sizeof(void *) * 8, flags);
		assert_ptr_not_null(allocated[i], "Unexpected mallocx failure");
		if (i > 0) {
			/* Emulate a circular list. */
			*(void **)allocated[i] = allocated[i - 1];
		}
	}
	*(void **)allocated[0] = allocated[n_alloc - 1];
	expect_true(check_allocated_intact(allocated, n_alloc),
	    "Allocated data corrupted");

	for (unsigned i = 0; i < n_alloc; i++) {
		free(allocated[i]);
	}
	/* Read-after-free */
	expect_false(check_allocated_intact(allocated, n_alloc),
	    "Junk-filling not detected");

	test_write_after_free_pre();
	for (unsigned i = 0; i < n_alloc; i++) {
		allocated[i] = mallocx(sizeof(void *), flags);
		assert_ptr_not_null(allocated[i], "Unexpected mallocx failure");
		*(void **)allocated[i] = (void *)(uintptr_t)i;
	}
	/* Write-after-free */
	for (unsigned i = 0; i < n_alloc; i++) {
		free(allocated[i]);
		*(void **)allocated[i] = NULL;
	}
	test_write_after_free_post();
}
TEST_END

int
main(void) {
	/*
	 * Run the three UAF test cases in order. Whatever the harness returns
	 * becomes the process exit status.
	 */
	return test(test_read_after_free, test_write_after_free,
	    test_use_after_free_integration);
}