Commit a0019cd7 authored by Jiri Olsa's avatar Jiri Olsa Committed by Alexei Starovoitov
Browse files

lib/sort: Add priv pointer to swap function



Adding support to have priv pointer in swap callback function.

Following the initial change on cmp callback functions [1]
and adding SWAP_WRAPPER macro to identify sort call of sort_r.

Signed-off-by: default avatarJiri Olsa <jolsa@kernel.org>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Reviewed-by: default avatarMasami Hiramatsu <mhiramat@kernel.org>
Link: https://lore.kernel.org/bpf/20220316122419.933957-2-jolsa@kernel.org

[1] 4333fb96 ("media: lib/sort.c: implement sort() variant taking context argument")
parent 245d9496
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@

void sort_r(void *base, size_t num, size_t size,
	    cmp_r_func_t cmp_func,
	    swap_func_t swap_func,
	    swap_r_func_t swap_func,
	    const void *priv);

void sort(void *base, size_t num, size_t size,
+1 −0
Original line number Diff line number Diff line
@@ -226,6 +226,7 @@ struct callback_head {
typedef void (*rcu_callback_t)(struct rcu_head *head);
typedef void (*call_rcu_func_t)(struct rcu_head *head, rcu_callback_t func);

typedef void (*swap_r_func_t)(void *a, void *b, int size, const void *priv);
typedef void (*swap_func_t)(void *a, void *b, int size);

typedef int (*cmp_r_func_t)(const void *a, const void *b, const void *priv);
+30 −10
Original line number Diff line number Diff line
@@ -122,16 +122,27 @@ static void swap_bytes(void *a, void *b, size_t n)
 * a pointer, but small integers make for the smallest compare
 * instructions.
 */
#define SWAP_WORDS_64 (swap_func_t)0
#define SWAP_WORDS_32 (swap_func_t)1
#define SWAP_BYTES    (swap_func_t)2
#define SWAP_WORDS_64 (swap_r_func_t)0
#define SWAP_WORDS_32 (swap_r_func_t)1
#define SWAP_BYTES    (swap_r_func_t)2
#define SWAP_WRAPPER  (swap_r_func_t)3

struct wrapper {
	cmp_func_t cmp;
	swap_func_t swap;
};

/*
 * The function pointer is last to make tail calls most efficient if the
 * compiler decides not to inline this function.
 */
static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func)
static void do_swap(void *a, void *b, size_t size, swap_r_func_t swap_func, const void *priv)
{
	if (swap_func == SWAP_WRAPPER) {
		((const struct wrapper *)priv)->swap(a, b, (int)size);
		return;
	}

	if (swap_func == SWAP_WORDS_64)
		swap_words_64(a, b, size);
	else if (swap_func == SWAP_WORDS_32)
@@ -139,7 +150,7 @@ static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func)
	else if (swap_func == SWAP_BYTES)
		swap_bytes(a, b, size);
	else
		swap_func(a, b, (int)size);
		swap_func(a, b, (int)size, priv);
}

#define _CMP_WRAPPER ((cmp_r_func_t)0L)
@@ -147,7 +158,7 @@ static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func)
static int do_cmp(const void *a, const void *b, cmp_r_func_t cmp, const void *priv)
{
	if (cmp == _CMP_WRAPPER)
		return ((cmp_func_t)(priv))(a, b);
		return ((const struct wrapper *)priv)->cmp(a, b);
	return cmp(a, b, priv);
}

@@ -198,7 +209,7 @@ static size_t parent(size_t i, unsigned int lsbit, size_t size)
 */
void sort_r(void *base, size_t num, size_t size,
	    cmp_r_func_t cmp_func,
	    swap_func_t swap_func,
	    swap_r_func_t swap_func,
	    const void *priv)
{
	/* pre-scale counters for performance */
@@ -208,6 +219,10 @@ void sort_r(void *base, size_t num, size_t size,
	if (!a)		/* num < 2 || size == 0 */
		return;

	/* called from 'sort' without swap function, let's pick the default */
	if (swap_func == SWAP_WRAPPER && !((struct wrapper *)priv)->swap)
		swap_func = NULL;

	if (!swap_func) {
		if (is_aligned(base, size, 8))
			swap_func = SWAP_WORDS_64;
@@ -230,7 +245,7 @@ void sort_r(void *base, size_t num, size_t size,
		if (a)			/* Building heap: sift down --a */
			a -= size;
		else if (n -= size)	/* Sorting: Extract root to --n */
			do_swap(base, base + n, size, swap_func);
			do_swap(base, base + n, size, swap_func, priv);
		else			/* Sort complete */
			break;

@@ -257,7 +272,7 @@ void sort_r(void *base, size_t num, size_t size,
		c = b;			/* Where "a" belongs */
		while (b != a) {	/* Shift it into place */
			b = parent(b, lsbit, size);
			do_swap(base + b, base + c, size, swap_func);
			do_swap(base + b, base + c, size, swap_func, priv);
		}
	}
}
@@ -267,6 +282,11 @@ void sort(void *base, size_t num, size_t size,
	  cmp_func_t cmp_func,
	  swap_func_t swap_func)
{
	return sort_r(base, num, size, _CMP_WRAPPER, swap_func, cmp_func);
	struct wrapper w = {
		.cmp  = cmp_func,
		.swap = swap_func,
	};

	return sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w);
}
EXPORT_SYMBOL(sort);