/*	$NetBSD: if_laggproto.c,v 1.16 2024/09/26 06:08:24 rin Exp $	*/

/*-
 * SPDX-License-Identifier: BSD-2-Clause-NetBSD
 *
 * Copyright (c)2021 Internet Initiative Japan, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <sys/cdefs.h>
__KERNEL_RCSID(0, "$NetBSD: if_laggproto.c,v 1.16 2024/09/26 06:08:24 rin Exp $");

#include <sys/param.h>
#include <sys/types.h>

#include <sys/evcnt.h>
#include <sys/kmem.h>
#include <sys/mbuf.h>
#include <sys/mutex.h>
#include <sys/pslist.h>
#include <sys/syslog.h>
#include <sys/workqueue.h>

#include <net/if.h>
#include <net/if_ether.h>
#include <net/if_media.h>

#include <net/lagg/if_lagg.h>
#include <net/lagg/if_laggproto.h>

/*
 * Per-lagg state shared by the protocol implementations in this file.
 * Allocated by lagg_proto_alloc() and released by lagg_proto_free().
 */
struct lagg_proto_softc {
	struct lagg_softc	*psc_softc;	/* lagg softc given at alloc */
	struct pslist_head	 psc_ports;	/* started ports, by lp_prio */
	kmutex_t		 psc_lock;	/* see LAGG_PROTO_LOCK() */
	pserialize_t		 psc_psz;	/* for lockless readers */
	size_t			 psc_ctxsiz;	/* size of psc_ctx, 0 if none */
	void			*psc_ctx;	/* protocol private context */
	size_t			 psc_nactports;	/* number of active ports */
	struct workqueue	*psc_workq;	/* runs psc_work_linkspeed */
	struct lagg_work	 psc_work_linkspeed; /* link speed update */
};

/*
 * Locking notes:
 * - Members of struct lagg_proto_softc are protected by
 *   psc_lock (an adaptive mutex).
 * - psc_ports is read under pserialize (psc_psz) and is
 *   updated exclusively under LAGG_PROTO_LOCK.
 * - Other locking notes are described in if_laggproto.h
 */

/*
 * Private context of the failover protocol.
 */
struct lagg_failover {
	/*
	 * When true, lagg_fail_input() accepts frames from any port;
	 * otherwise only frames received on the port currently chosen
	 * by lagg_link_active() are accepted.
	 */
	bool		 fo_rx_all;
};

/*
 * One snapshot of the ports used by the loadbalance protocol.
 */
struct lagg_portmap {
	struct lagg_port	*pm_ports[LAGG_MAX_PORTS]; /* hash buckets */
	size_t			 pm_nports;	/* valid entries in pm_ports */
};

/*
 * A pair of port maps: one is published to readers while the other
 * is rebuilt by the writer and then swapped in.
 */
struct lagg_portmaps {
	struct lagg_portmap	 maps_pmap[2];
	size_t			 maps_activepmap; /* index of published map */
};

/*
 * Private context of the loadbalance protocol.
 */
struct lagg_lb {
	struct lagg_portmaps	 lb_pmaps;	/* hash -> port lookup maps */
};

/*
 * Per-port protocol state, stored in lp_proto_ctx of the lagg port.
 */
struct lagg_proto_port {
	struct pslist_entry	 lpp_entry;	/* linkage on psc_ports */
	struct lagg_port	*lpp_laggport;	/* back pointer to the port */
	uint64_t		 lpp_linkspeed;	/* from ifmedia_baudrate() */
	bool			 lpp_active;	/* last lagg_portactive() result */
	bool			 lpp_running;	/* between startport and stopport */
};

/* Helpers for psc_lock: mutex_enter(), mutex_exit() and mutex_owned(). */
#define LAGG_PROTO_LOCK(_psc)	mutex_enter(&(_psc)->psc_lock)
#define LAGG_PROTO_UNLOCK(_psc)	mutex_exit(&(_psc)->psc_lock)
#define LAGG_PROTO_LOCKED(_psc)	mutex_owned(&(_psc)->psc_lock)

/* Forward declarations of the file-local helpers defined below. */
static struct lagg_proto_softc *
		lagg_proto_alloc(lagg_proto, struct lagg_softc *);
static void	lagg_proto_free(struct lagg_proto_softc *);
static void	lagg_proto_insert_port(struct lagg_proto_softc *,
		    struct lagg_proto_port *);
static void	lagg_proto_remove_port(struct lagg_proto_softc *,
		    struct lagg_proto_port *);
static struct lagg_port *
		lagg_link_active(struct lagg_proto_softc *psc,
		    struct lagg_proto_port *, struct psref *);
static void	lagg_fail_linkspeed_work(struct lagg_work *, void *);
static void	lagg_lb_linkspeed_work(struct lagg_work*,
		    void *);
static void	lagg_common_linkstate(struct lagg_proto_softc *,
		    struct lagg_port *);

/*
 * Return the port map selected by maps_activepmap, i.e. the one
 * currently published to readers.
 */
static inline struct lagg_portmap *
lagg_portmap_active(struct lagg_portmaps *maps)
{
	const size_t idx = atomic_load_consume(&maps->maps_activepmap);

	return &maps->maps_pmap[idx];
}

/*
 * Return the port map that is NOT selected by maps_activepmap,
 * i.e. the one a writer may rebuild before switching.
 */
static inline struct lagg_portmap *
lagg_portmap_next(struct lagg_portmaps *maps)
{
	const size_t cur = atomic_load_consume(&maps->maps_activepmap);

	return &maps->maps_pmap[cur ^ 0x1];
}

/*
 * Publish the other port map by flipping maps_activepmap between
 * 0 and 1 with a release store.
 */
static inline void
lagg_portmap_switch(struct lagg_portmaps *maps)
{
	const size_t cur = atomic_load_consume(&maps->maps_activepmap);
	const size_t flipped = (cur & 0x1) ^ 0x1;

	atomic_store_release(&maps->maps_activepmap, flipped);
}

/*
 * Allocate and initialize the protocol softc for lagg instance sc.
 *
 * A private context is allocated for the failover and loadbalance
 * protocols; other protocols get none.  Returns NULL if any of the
 * allocations or the workqueue creation fails.
 */
static struct lagg_proto_softc *
lagg_proto_alloc(lagg_proto pr, struct lagg_softc *sc)
{
	struct lagg_proto_softc *newpsc;
	char wqname[MAXCOMLEN];
	size_t privsz;

	if (pr == LAGG_PROTO_FAILOVER)
		privsz = sizeof(struct lagg_failover);
	else if (pr == LAGG_PROTO_LOADBALANCE)
		privsz = sizeof(struct lagg_lb);
	else
		privsz = 0;

	newpsc = kmem_zalloc(sizeof(*newpsc), KM_NOSLEEP);
	if (newpsc == NULL)
		return NULL;

	/* the workqueue is named after the lagg interface */
	snprintf(wqname, sizeof(wqname), "%s.proto", sc->sc_if.if_xname);
	newpsc->psc_workq = lagg_workq_create(wqname,
	    PRI_SOFTNET, IPL_SOFTNET, WQ_MPSAFE);
	if (newpsc->psc_workq == NULL) {
		LAGG_LOG(sc, LOG_ERR, "workqueue create failed\n");
		goto fail_psc;
	}

	if (privsz > 0) {
		newpsc->psc_ctx = kmem_zalloc(privsz, KM_NOSLEEP);
		if (newpsc->psc_ctx == NULL)
			goto fail_workq;

		/* remembered so that lagg_proto_free() can release it */
		newpsc->psc_ctxsiz = privsz;
	}

	newpsc->psc_softc = sc;
	PSLIST_INIT(&newpsc->psc_ports);
	newpsc->psc_psz = pserialize_create();
	mutex_init(&newpsc->psc_lock, MUTEX_DEFAULT, IPL_SOFTNET);

	return newpsc;

fail_workq:
	lagg_workq_destroy(newpsc->psc_workq);
fail_psc:
	kmem_free(newpsc, sizeof(*newpsc));
	return NULL;
}

/*
 * Release everything that lagg_proto_alloc() set up for psc,
 * including psc itself.
 */
static void
lagg_proto_free(struct lagg_proto_softc *psc)
{

	/*
	 * Presumably waits until the link speed work is no longer
	 * pending or running, before its resources go away.
	 */
	lagg_workq_wait(psc->psc_workq, &psc->psc_work_linkspeed);
	pserialize_destroy(psc->psc_psz);
	mutex_destroy(&psc->psc_lock);
	lagg_workq_destroy(psc->psc_workq);
	PSLIST_DESTROY(&psc->psc_ports);

	/* psc_ctx was allocated only when a non-zero size was recorded */
	if (psc->psc_ctxsiz > 0)
		kmem_free(psc->psc_ctx, psc->psc_ctxsiz);

	kmem_free(psc, sizeof(*psc));
}

static struct lagg_port *
lagg_link_active(struct lagg_proto_softc *psc,
    struct lagg_proto_port *pport, struct psref *psref)
{
	struct lagg_port *lp;
	int s;

	lp = NULL;
	s = pserialize_read_enter();

	for (;pport != NULL;
	    pport = PSLIST_READER_NEXT(pport,
	    struct lagg_proto_port, lpp_entry)) {
		if (atomic_load_relaxed(&pport->lpp_active)) {
			lp = pport->lpp_laggport;
			goto done;
		}
	}

	PSLIST_READER_FOREACH(pport, &psc->psc_ports,
	    struct lagg_proto_port, lpp_entry) {
		if (atomic_load_relaxed(&pport->lpp_active)) {
			lp = pport->lpp_laggport;
			break;
		}
	}
done:
	if (lp != NULL)
		lagg_port_getref(lp, psref);
	pserialize_read_exit(s);

	return lp;
}

/*
 * Allocate the per-port protocol state for lp and attach it to
 * lp->lp_proto_ctx.  Must be called with the lagg lock held.
 * Returns 0 on success or ENOMEM if the allocation fails.
 */
int
lagg_common_allocport(struct lagg_proto_softc *psc, struct lagg_port *lp)
{
	struct lagg_proto_port *newport;

	KASSERT(LAGG_LOCKED(psc->psc_softc));

	newport = kmem_zalloc(sizeof(*newport), KM_NOSLEEP);
	if (newport != NULL) {
		PSLIST_ENTRY_INIT(newport, lpp_entry);
		newport->lpp_laggport = lp;
		lp->lp_proto_ctx = newport;
		return 0;
	}

	return ENOMEM;
}

/*
 * Detach and free the per-port protocol state of lp.
 * The port must already have been stopped.
 */
void
lagg_common_freeport(struct lagg_proto_softc *psc, struct lagg_port *lp)
{
	struct lagg_proto_port *victim = lp->lp_proto_ctx;

	KASSERT(!victim->lpp_running);

	lp->lp_proto_ctx = NULL;
	kmem_free(victim, sizeof(*victim));
}

static void
lagg_proto_insert_port(struct lagg_proto_softc *psc,
    struct lagg_proto_port *pport)
{
	struct lagg_proto_port *pport0;
	struct lagg_port *lp, *lp0;
	bool insert_after;

	insert_after = false;
	lp = pport->lpp_laggport;

	LAGG_PROTO_LOCK(psc);
	PSLIST_WRITER_FOREACH(pport0, &psc->psc_ports,
	    struct lagg_proto_port, lpp_entry) {
		lp0 = pport0->lpp_laggport;
		if (lp0->lp_prio > lp->lp_prio)
			break;

		if (PSLIST_WRITER_NEXT(pport0,
		    struct lagg_proto_port, lpp_entry) == NULL) {
			insert_after = true;
			break;
		}
	}

	if (pport0 == NULL) {
		PSLIST_WRITER_INSERT_HEAD(&psc->psc_ports, pport,
		    lpp_entry);
	} else if (insert_after) {
		PSLIST_WRITER_INSERT_AFTER(pport0, pport, lpp_entry);
	} else {
		PSLIST_WRITER_INSERT_BEFORE(pport0, pport, lpp_entry);
	}
	LAGG_PROTO_UNLOCK(psc);
}

/*
 * Unlink pport from psc_ports and leave its list entry ready to be
 * inserted again.
 */
static void
lagg_proto_remove_port(struct lagg_proto_softc *psc,
    struct lagg_proto_port *pport)
{

	LAGG_PROTO_LOCK(psc);
	PSLIST_WRITER_REMOVE(pport, lpp_entry);
	LAGG_PROTO_UNLOCK(psc);
	/* done after dropping psc_lock, before the entry is reset */
	pserialize_perform(psc->psc_psz);

	/* re-initialize for reuse */
	PSLIST_ENTRY_DESTROY(pport, lpp_entry);
	PSLIST_ENTRY_INIT(pport, lpp_entry);
}

/*
 * Start protocol processing on lp: put its protocol state on
 * psc_ports, mark it running, and evaluate its current link state.
 */
void
lagg_common_startport(struct lagg_proto_softc *psc, struct lagg_port *lp)
{
	struct lagg_proto_port *pport;

	pport = lp->lp_proto_ctx;
	lagg_proto_insert_port(psc, pport);

	LAGG_PROTO_LOCK(psc);
	pport->lpp_running = true;
	LAGG_PROTO_UNLOCK(psc);

	/* takes IFNET_LOCK of the port itself; see lagg_common_linkstate() */
	lagg_common_linkstate(psc, lp);
}

/*
 * Stop protocol processing on lp: clear its running flag, take it
 * off psc_ports and, if it was active, drop it from the active port
 * count.  The lagg interface is reported as link-down when the last
 * active port goes away.  A link speed update is scheduled at the end.
 */
void
lagg_common_stopport(struct lagg_proto_softc *psc, struct lagg_port *lp)
{
	struct lagg_proto_port *pport;
	struct ifnet *ifp;

	pport = lp->lp_proto_ctx;

	/* from here on link state changes of this port are ignored */
	LAGG_PROTO_LOCK(psc);
	pport->lpp_running = false;
	LAGG_PROTO_UNLOCK(psc);

	lagg_proto_remove_port(psc, pport);

	/*
	 * psc_nactports and lpp_active are also updated by
	 * lagg_common_linkstate_ifnet_locked() on behalf of other
	 * ports, so they must be changed under psc_lock here as well.
	 */
	LAGG_PROTO_LOCK(psc);
	if (pport->lpp_active) {
		KASSERT(psc->psc_nactports > 0);
		psc->psc_nactports--;

		if (psc->psc_nactports == 0) {
			ifp = &psc->psc_softc->sc_if;
			if_link_state_change(ifp, LINK_STATE_DOWN);
		}

		pport->lpp_active = false;
	}
	LAGG_PROTO_UNLOCK(psc);

	lagg_workq_add(psc->psc_workq, &psc->psc_work_linkspeed);
}
/*
 * Wrapper around lagg_common_linkstate_ifnet_locked() for callers
 * that do not hold the IFNET_LOCK of the port's interface.
 */
static void
lagg_common_linkstate(struct lagg_proto_softc *psc, struct lagg_port *lp)
{

	IFNET_ASSERT_UNLOCKED(lp->lp_ifp);

	IFNET_LOCK(lp->lp_ifp);
	lagg_common_linkstate_ifnet_locked(psc, lp);
	IFNET_UNLOCK(lp->lp_ifp);
}

void
lagg_common_linkstate_ifnet_locked(struct lagg_proto_softc *psc, struct lagg_port *lp)
{
	struct lagg_proto_port *pport;
	struct ifnet *ifp, *ifp_port;
	struct ifmediareq ifmr;
	uint64_t linkspeed;
	bool is_active;
	int error;

	pport = lp->lp_proto_ctx;
	is_active = lagg_portactive(lp);
	ifp_port = lp->lp_ifp;

	KASSERT(IFNET_LOCKED(ifp_port));

	LAGG_PROTO_LOCK(psc);
	if (!pport->lpp_running ||
	    pport->lpp_active == is_active) {
		LAGG_PROTO_UNLOCK(psc);
		return;
	}

	ifp = &psc->psc_softc->sc_if;
	pport->lpp_active = is_active;

	if (is_active) {
		psc->psc_nactports++;
		if (psc->psc_nactports == 1)
			if_link_state_change(ifp, LINK_STATE_UP);
	} else {
		KASSERT(psc->psc_nactports > 0);
		psc->psc_nactports--;

		if (psc->psc_nactports == 0)
			if_link_state_change(ifp, LINK_STATE_DOWN);
	}
	LAGG_PROTO_UNLOCK(psc);

	memset(&ifmr, 0, sizeof(ifmr));
	error = if_ioctl(ifp_port, SIOCGIFMEDIA, (void *)&ifmr);
	if (error == 0) {
		linkspeed = ifmedia_baudrate(ifmr.ifm_active);
	} else {
		linkspeed = 0;
	}

	LAGG_PROTO_LOCK(psc);
	pport->lpp_linkspeed = linkspeed;
	LAGG_PROTO_UNLOCK(psc);
	lagg_workq_add(psc->psc_workq, &psc->psc_work_linkspeed);
}

/*
 * Detach the protocol: release psc through lagg_proto_free().
 * psc must not be used afterwards.
 */
void
lagg_common_detach(struct lagg_proto_softc *psc)
{

	lagg_proto_free(psc);
}

/*
 * Attach routine that allocates no protocol state: *pscp is set to
 * NULL and 0 is always returned.
 */
int
lagg_none_attach(struct lagg_softc *sc, struct lagg_proto_softc **pscp)
{

	*pscp = NULL;
	return 0;
}

/*
 * Attach the failover protocol to sc.  Reception on all ports is
 * enabled by default.  On success the new protocol softc is stored
 * in *xpsc and 0 is returned; ENOMEM is returned if it cannot be
 * allocated.
 */
int
lagg_fail_attach(struct lagg_softc *sc, struct lagg_proto_softc **xpsc)
{
	struct lagg_proto_softc *newpsc;
	struct lagg_failover *fo;

	newpsc = lagg_proto_alloc(LAGG_PROTO_FAILOVER, sc);
	if (newpsc == NULL)
		return ENOMEM;

	lagg_work_set(&newpsc->psc_work_linkspeed,
	    lagg_fail_linkspeed_work, newpsc);

	fo = newpsc->psc_ctx;
	fo->fo_rx_all = true;

	*xpsc = newpsc;
	return 0;
}

/*
 * Transmit m through the first active port.  If there is none, the
 * mbuf is freed, the output error counter of the lagg interface is
 * incremented and ENOENT is returned; otherwise 0 is returned.
 */
int
lagg_fail_transmit(struct lagg_proto_softc *psc, struct mbuf *m)
{
	struct lagg_port *txport;
	struct psref psref;

	txport = lagg_link_active(psc, NULL, &psref);
	if (txport != NULL) {
		lagg_output(psc->psc_softc, txport, m);
		lagg_port_putref(txport, &psref);
		return 0;
	}

	/* no usable port */
	if_statinc(&psc->psc_softc->sc_if, if_oerrors);
	m_freem(m);
	return ENOENT;
}

struct mbuf *
lagg_fail_input(struct lagg_proto_softc *psc, struct lagg_port *lp,
    struct mbuf *m)
{
	struct lagg_failover *fovr;
	struct lagg_port *lp0;
	struct ifnet *ifp;
	struct psref psref;

	fovr = psc->psc_ctx;
	if (atomic_load_relaxed(&fovr->fo_rx_all))
		return m;

	lp0 = lagg_link_active(psc, NULL, &psref);
	if (lp0 == NULL) {
		goto drop;
	}

	if (lp0 != lp) {
		lagg_port_putref(lp0, &psref);
		goto drop;
	}

	lagg_port_putref(lp0, &psref);

	return m;
drop:
	ifp = &psc->psc_softc->sc_if;
	if_statinc(ifp, if_ierrors);
	m_freem(m);
	return NULL;
}

/*
 * Report the failover state of lp in resp->rp_flags.
 *
 * Nothing is set for an inactive port.  The port chosen by
 * lagg_link_active() is reported as active, collecting and
 * distributing; any other active port is reported as collecting
 * only when fo_rx_all is set.
 */
void
lagg_fail_portstat(struct lagg_proto_softc *psc, struct lagg_port *lp,
    struct laggreqport *resp)
{
	struct lagg_failover *fo = psc->psc_ctx;
	struct lagg_proto_port *pp = lp->lp_proto_ctx;
	struct lagg_port *active;
	struct psref psref;

	if (!pp->lpp_active)
		return;

	active = lagg_link_active(psc, NULL, &psref);
	if (active == lp) {
		SET(resp->rp_flags,
		    (LAGG_PORT_ACTIVE |
		    LAGG_PORT_COLLECTING |
		    LAGG_PORT_DISTRIBUTING));
	} else if (fo->fo_rx_all) {
		SET(resp->rp_flags,
		    LAGG_PORT_COLLECTING);
	}

	if (active != NULL)
		lagg_port_putref(active, &psref);
}

/*
 * Handle a failover protocol request.
 *
 * LAGGIOC_FAILSETFLAGS and LAGGIOC_FAILCLRFLAGS set or clear the
 * options named in rpfail->flags; only LAGGREQFAIL_RXALL is acted on.
 * Returns 0 on success or ENOTTY for an unknown command.
 */
int
lagg_fail_ioctl(struct lagg_proto_softc *psc, struct laggreqproto *lreq)
{
	struct lagg_failover *fovr;
	struct laggreq_fail *rpfail;
	int error;
	bool set;

	error = 0;
	fovr = psc->psc_ctx;
	rpfail = &lreq->rp_fail;

	switch (rpfail->command) {
	case LAGGIOC_FAILSETFLAGS:
	case LAGGIOC_FAILCLRFLAGS:
		set = (rpfail->command == LAGGIOC_FAILSETFLAGS);

		/*
		 * fo_rx_all is read without a lock by lagg_fail_input()
		 * using atomic_load_relaxed(), so pair it with an
		 * atomic store here.
		 */
		if (ISSET(rpfail->flags, LAGGREQFAIL_RXALL))
			atomic_store_relaxed(&fovr->fo_rx_all, set);
		break;
	default:
		error = ENOTTY;
		break;
	}

	return error;
}

/*
 * Work handler of the failover protocol: set the link speed of the
 * lagg interface to that of the port chosen by lagg_link_active(),
 * or to 0 if no port is active.
 *
 * xpsc is the struct lagg_proto_softc registered with lagg_work_set().
 */
static void
lagg_fail_linkspeed_work(struct lagg_work *_lw __unused, void *xpsc)
{
	struct lagg_proto_softc *psc = xpsc;
	struct lagg_proto_port *pport;
	struct lagg_port *lp;
	struct psref psref;
	uint64_t linkspeed;

	/*
	 * NOTE(review): preemption is presumably disabled so that the
	 * psref is acquired and released on the same CPU -- confirm.
	 */
	kpreempt_disable();
	lp = lagg_link_active(psc, NULL, &psref);
	if (lp != NULL) {
		pport = lp->lp_proto_ctx;
		LAGG_PROTO_LOCK(psc);	/* acquired to refer lpp_linkspeed */
		linkspeed = pport->lpp_linkspeed;
		LAGG_PROTO_UNLOCK(psc);
		lagg_port_putref(lp, &psref);
	} else {
		linkspeed = 0;
	}
	kpreempt_enable();

	LAGG_LOCK(psc->psc_softc);
	lagg_set_linkspeed(psc->psc_softc, linkspeed);
	LAGG_UNLOCK(psc->psc_softc);
}

/*
 * Attach the loadbalance protocol to sc.  On success the new
 * protocol softc is stored in *xpsc and 0 is returned; ENOMEM is
 * returned if it cannot be allocated.
 */
int
lagg_lb_attach(struct lagg_softc *sc, struct lagg_proto_softc **xpsc)
{
	struct lagg_proto_softc *newpsc;
	struct lagg_lb *lbctx;

	newpsc = lagg_proto_alloc(LAGG_PROTO_LOADBALANCE, sc);
	if (newpsc == NULL)
		return ENOMEM;

	lagg_work_set(&newpsc->psc_work_linkspeed,
	    lagg_lb_linkspeed_work, newpsc);

	/* start with the first of the two port maps published */
	lbctx = newpsc->psc_ctx;
	lbctx->lb_pmaps.maps_activepmap = 0;

	*xpsc = newpsc;
	return 0;
}

void
lagg_lb_startport(struct lagg_proto_softc *psc, struct lagg_port *lp)
{
	struct lagg_lb *lb;
	struct lagg_portmap *pm_act, *pm_next;
	size_t n;

	lb = psc->psc_ctx;
	lagg_common_startport(psc, lp);

	LAGG_PROTO_LOCK(psc);
	pm_act = lagg_portmap_active(&lb->lb_pmaps);
	pm_next = lagg_portmap_next(&lb->lb_pmaps);

	*pm_next = *pm_act;

	n = pm_next->pm_nports;
	pm_next->pm_ports[n] = lp;

	n++;
	pm_next->pm_nports = n;

	lagg_portmap_switch(&lb->lb_pmaps);
	LAGG_PROTO_UNLOCK(psc);
	pserialize_perform(psc->psc_psz);
}

/*
 * Stop lp for the loadbalance protocol.
 *
 * The spare port map is rebuilt from the published one without lp
 * and switched in; pserialize_perform() is called after psc_lock is
 * released, and only then does the common stop processing run.
 */
void
lagg_lb_stopport(struct lagg_proto_softc *psc, struct lagg_port *lp)
{
	struct lagg_lb *lbctx = psc->psc_ctx;
	struct lagg_portmap *cur, *nxt;
	size_t src, dst;

	LAGG_PROTO_LOCK(psc);
	cur = lagg_portmap_active(&lbctx->lb_pmaps);
	nxt = lagg_portmap_next(&lbctx->lb_pmaps);

	/* copy every port except lp, compacting as we go */
	dst = 0;
	for (src = 0; src < cur->pm_nports; src++) {
		if (cur->pm_ports[src] != lp)
			nxt->pm_ports[dst++] = cur->pm_ports[src];
	}
	nxt->pm_nports = dst;

	lagg_portmap_switch(&lbctx->lb_pmaps);
	LAGG_PROTO_UNLOCK(psc);
	pserialize_perform(psc->psc_psz);

	lagg_common_stopport(psc, lp);
}

/*
 * Transmit m through a port selected by the hash of the packet.
 *
 * The hash picks an entry of the published port map; that port is
 * used as the starting point for lagg_link_active(), so an active
 * port is still found when the hashed one is not active.  If no
 * port is available the mbuf is freed, the output error counter is
 * incremented and ENOENT is returned; otherwise 0 is returned.
 */
int
lagg_lb_transmit(struct lagg_proto_softc *psc, struct mbuf *m)
{
	struct lagg_lb *lbctx = psc->psc_ctx;
	struct lagg_portmap *map;
	struct lagg_port *txport = NULL;
	struct psref psref;
	uint32_t h;
	int s;

	h = lagg_hashmbuf(psc->psc_softc, m);

	s = pserialize_read_enter();
	map = lagg_portmap_active(&lbctx->lb_pmaps);
	if (__predict_true(map->pm_nports != 0)) {
		struct lagg_port *hint;

		hint = map->pm_ports[h % map->pm_nports];
		txport = lagg_link_active(psc, hint->lp_proto_ctx, &psref);
	}
	pserialize_read_exit(s);

	if (__predict_false(txport == NULL)) {
		if_statinc(&psc->psc_softc->sc_if, if_oerrors);
		m_freem(m);
		return ENOENT;
	}

	lagg_output(psc->psc_softc, txport, m);
	lagg_port_putref(txport, &psref);

	return 0;
}

/*
 * Input filter of the loadbalance protocol: every frame is accepted,
 * so m is returned unchanged.
 */
struct mbuf *
lagg_lb_input(struct lagg_proto_softc *psc __unused,
    struct lagg_port *lp __unused, struct mbuf *m)
{

	return m;
}

/*
 * Report the loadbalance state of lp in resp->rp_flags: an active
 * port is active, collecting and distributing; nothing is set for
 * an inactive one.
 */
void
lagg_lb_portstat(struct lagg_proto_softc *psc, struct lagg_port *lp,
    struct laggreqport *resp)
{
	struct lagg_proto_port *pp = lp->lp_proto_ctx;

	if (!pp->lpp_active)
		return;

	SET(resp->rp_flags, LAGG_PORT_ACTIVE |
	    LAGG_PORT_COLLECTING | LAGG_PORT_DISTRIBUTING);
}

/*
 * Work handler of the loadbalance protocol: set the link speed of
 * the lagg interface to the highest link speed among the active
 * ports, or to 0 if there is none.
 *
 * xpsc is the struct lagg_proto_softc registered with lagg_work_set().
 */
static void
lagg_lb_linkspeed_work(struct lagg_work *_lw __unused, void *xpsc)
{
	struct lagg_proto_softc *psc = xpsc;
	struct lagg_proto_port *pp;
	uint64_t best = 0;

	LAGG_PROTO_LOCK(psc); /* acquired to refer lpp_linkspeed */
	PSLIST_READER_FOREACH(pp, &psc->psc_ports,
	    struct lagg_proto_port, lpp_entry) {
		if (!pp->lpp_active)
			continue;

		if (pp->lpp_linkspeed > best)
			best = pp->lpp_linkspeed;
	}
	LAGG_PROTO_UNLOCK(psc);

	LAGG_LOCK(psc->psc_softc);
	lagg_set_linkspeed(psc->psc_softc, best);
	LAGG_UNLOCK(psc->psc_softc);
}
