$OpenBSD: patch-src_rttable_c,v 1.1 2021/01/12 17:59:49 sthen Exp $

Index: src/rttable.c
--- src/rttable.c.orig
+++ src/rttable.c
@@ -39,17 +39,23 @@
 */
 
 #include "igmpproxy.h"
+#include <sys/queue.h>
+#include <sys/tree.h>
 
 #define MAX_ORIGINS 4
 
 /**
 *   Routing table structure definition. Double linked list...
 */
+struct Origin {
+    TAILQ_ENTRY(Origin) next;
+    uint32_t        originAddr;
+    int             flood;
+    uint32_t        pktcnt;
+};
 struct RouteTable {
-    struct RouteTable   *nextroute;     // Pointer to the next group in line.
-    struct RouteTable   *prevroute;     // Pointer to the previous group in line.
+    RB_ENTRY(RouteTable) entry;
     uint32_t            group;          // The group to route
-    uint32_t            originAddrs[MAX_ORIGINS]; // The origin adresses (only set on activated routes)
     uint32_t            vifBits;        // Bits representing recieving VIFs.
 
     // Keeps the upstream membership state...
@@ -60,22 +66,54 @@ struct RouteTable {
     uint32_t            ageVifBits;     // Bits representing aging VIFs.
     int                 ageValue;       // Downcounter for death.
     int                 ageActivity;    // Records any acitivity that notes there are still listeners.
+    TAILQ_HEAD(originhead, Origin) originList; // The origin adresses (non-empty on activated routes)
 
     // Keeps downstream hosts information
     uint32_t            downstreamHostsHashSeed;
     uint8_t             downstreamHostsHashTable[];
 };
+RB_HEAD(rtabletree, RouteTable) routing_table =
+    RB_INITIALIZER(&routing_table);
 
-
-// Keeper for the routing table...
-static struct RouteTable   *routing_table;
-
 // Prototypes
 void logRouteTable(const char *header);
 int internAgeRoute(struct RouteTable *croute);
-int internUpdateKernelRoute(struct RouteTable *route, int activate);
+int internUpdateKernelRoute(struct RouteTable *route, int activate, struct Origin *);
 
+int                   rtable_cmp(struct RouteTable *, struct RouteTable *);
+struct RouteTable    *rtable_add(struct RouteTable *);
+void                  rtable_remove(struct RouteTable *);
+RB_GENERATE(rtabletree, RouteTable, entry, rtable_cmp);
 
+int
+rtable_cmp(struct RouteTable *rt, struct RouteTable *rtn)
+{
+    if (rt->group < rtn->group)
+        return (-1);
+
+    return (rt->group > rtn->group);
+}
+
+struct RouteTable *
+rtable_add(struct RouteTable *rt)
+{
+    return (RB_INSERT(rtabletree, &routing_table, rt));
+}
+
+void
+rtable_remove(struct RouteTable *rt)
+{
+    struct Origin        *o;
+
+        while ((o = TAILQ_FIRST(&rt->originList))) {
+        TAILQ_REMOVE(&rt->originList, o, next);
+        free(o);
+        }
+
+    RB_REMOVE(rtabletree, &routing_table, rt);
+    free(rt);
+}
+
 /**
 *   Functions for downstream hosts hash table
 */
@@ -120,9 +158,6 @@ void initRouteTable(void) {
     unsigned Ix;
     struct IfDesc *Dp;
 
-    // Clear routing table...
-    routing_table = NULL;
-
     // Join the all routers group on downstream vifs...
     for ( Ix = 0; (Dp = getIfByIx(Ix)); Ix++ ) {
         // If this is a downstream vif, we should join the All routers group...
@@ -213,29 +248,24 @@ static void sendJoinLeaveUpstream(struct RouteTable* r
 *   Clear all routes from routing table, and alerts Leaves upstream.
 */
 void clearAllRoutes(void) {
-    struct RouteTable   *croute, *remainroute;
+    struct RouteTable   *croute;
 
     // Loop through all routes...
-    for(croute = routing_table; croute; croute = remainroute) {
-
-        remainroute = croute->nextroute;
-
+    while ((croute = RB_ROOT(&routing_table)) != NULL) {
         // Log the cleanup in debugmode...
         my_log(LOG_DEBUG, 0, "Removing route entry for %s",
                      inetFmt(croute->group, s1));
 
         // Uninstall current route
-        if(!internUpdateKernelRoute(croute, 0)) {
+        if(!internUpdateKernelRoute(croute, 0, NULL)) {
             my_log(LOG_WARNING, 0, "The removal from Kernel failed.");
         }
 
         // Send Leave message upstream.
         sendJoinLeaveUpstream(croute, 0);
 
-        // Clear memory, and set pointer to next route...
-        free(croute);
+        rtable_remove(croute);
     }
-    routing_table = NULL;
 
     // Send a notice that the routing table is empty...
     my_log(LOG_NOTICE, 0, "All routes removed. Routing table is empty.");
@@ -246,15 +276,10 @@ void clearAllRoutes(void) {
 *   Route Descriptor.
 */
 static struct RouteTable *findRoute(uint32_t group) {
-    struct RouteTable*  croute;
+    struct RouteTable       key;
 
-    for(croute = routing_table; croute; croute = croute->nextroute) {
-        if(croute->group == group) {
-            return croute;
-        }
-    }
-
-    return NULL;
+    key.group = group;
+    return (RB_FIND(rtabletree, &routing_table, &key));
 }
 
 /**
@@ -293,10 +318,8 @@ int insertRoute(uint32_t group, int ifx, uint32_t src)
         newroute = (struct RouteTable*)malloc(sizeof(struct RouteTable) + (conf->fastUpstreamLeave ? conf->downstreamHostsHashTableSize : 0));
         // Insert the route desc and clear all pointers...
         newroute->group      = group;
-        memset(newroute->originAddrs, 0, MAX_ORIGINS * sizeof(newroute->originAddrs[0]));
-        newroute->nextroute  = NULL;
-        newroute->prevroute  = NULL;
         newroute->upstrVif   = -1;
+        TAILQ_INIT(&newroute->originList);
 
         if(conf->fastUpstreamLeave) {
             // Init downstream hosts bit hash table
@@ -321,54 +344,13 @@ int insertRoute(uint32_t group, int ifx, uint32_t src)
             BIT_SET(newroute->vifBits, ifx);
         }
 
-        // Check if there is a table already....
-        if(routing_table == NULL) {
-            // No location set, so insert in on the table top.
-            routing_table = newroute;
-            my_log(LOG_DEBUG, 0, "No routes in table. Insert at beginning.");
-        } else {
-
-            my_log(LOG_DEBUG, 0, "Found existing routes. Find insert location.");
-
-            // Check if the route could be inserted at the beginning...
-            if(routing_table->group > group) {
-                my_log(LOG_DEBUG, 0, "Inserting at beginning, before route %s",inetFmt(routing_table->group,s1));
-
-                // Insert at beginning...
-                newroute->nextroute = routing_table;
-                newroute->prevroute = NULL;
-                routing_table = newroute;
-
-                // If the route has a next node, the previous pointer must be updated.
-                if(newroute->nextroute != NULL) {
-                    newroute->nextroute->prevroute = newroute;
-                }
-
-            } else {
-
-                // Find the location which is closest to the route.
-                for( croute = routing_table; croute->nextroute != NULL; croute = croute->nextroute ) {
-                    // Find insert position.
-                    if(croute->nextroute->group > group) {
-                        break;
-                    }
-                }
-
-                my_log(LOG_DEBUG, 0, "Inserting after route %s",inetFmt(croute->group,s1));
-
-                // Insert after current...
-                newroute->nextroute = croute->nextroute;
-                newroute->prevroute = croute;
-                if(croute->nextroute != NULL) {
-                    croute->nextroute->prevroute = newroute;
-                }
-                croute->nextroute = newroute;
-            }
+        if ((croute = rtable_add(newroute)) != NULL)
+            free(newroute);
+        else {
+            // Set the new route as the current...
+            croute = newroute;
         }
 
-        // Set the new route as the current...
-        croute = newroute;
-
         // Log the cleanup in debugmode...
         my_log(LOG_INFO, 0, "Inserted route table entry for %s on VIF #%d",
             inetFmt(croute->group, s1),ifx);
@@ -391,7 +373,7 @@ int insertRoute(uint32_t group, int ifx, uint32_t src)
             inetFmt(croute->group, s1), ifx);
 
         // Update route in kernel...
-        if(!internUpdateKernelRoute(croute, 1)) {
+        if(!internUpdateKernelRoute(croute, 1, NULL)) {
             my_log(LOG_WARNING, 0, "The insertion into Kernel failed.");
             return 0;
         }
@@ -413,7 +395,7 @@ int insertRoute(uint32_t group, int ifx, uint32_t src)
 *   activated, it's reinstalled in the kernel. If
 *   the route is activated, no originAddr is needed.
 */
-int activateRoute(uint32_t group, uint32_t originAddr, int upstrVif) {
+int activateRoute(uint32_t group, uint32_t originAddr, int downIf) {
     struct RouteTable*  croute;
     int result = 0;
 
@@ -432,43 +414,37 @@ int activateRoute(uint32_t group, uint32_t originAddr,
     }
 
     if(croute != NULL) {
+        struct Origin *o = NULL;
+        int found = 0;
+
         // If the origin address is set, update the route data.
         if(originAddr > 0) {
-            // find this origin, or an unused slot
-            int i;
-            for (i = 0; i < MAX_ORIGINS; i++) {
-                // unused slots are at the bottom, so we can't miss this origin
-                if (croute->originAddrs[i] == originAddr || croute->originAddrs[i] == 0) {
+            TAILQ_FOREACH(o, &croute->originList, next) {
+                my_log(LOG_INFO, 0, "Origin for route %s have %s, new %s",
+                    inetFmt(croute->group, s1),
+                    inetFmt(o->originAddr, s2),
+                    inetFmt(originAddr, s3));
+                if (o->originAddr==originAddr) {
+                    found++;
                     break;
                 }
             }
-
-            if (i == MAX_ORIGINS) {
-                i = MAX_ORIGINS - 1;
-
-                my_log(LOG_WARNING, 0, "Too many origins for route %s; replacing %s with %s",
+            if (!found) {
+                my_log(LOG_NOTICE, 0, "New origin for route %s is %s, flood %d",
                     inetFmt(croute->group, s1),
-                    inetFmt(croute->originAddrs[i], s2),
-                    inetFmt(originAddr, s3));
+                    inetFmt(originAddr, s3), downIf);
+                o = malloc(sizeof(*o));
+                o->originAddr = originAddr;
+                o->flood = downIf;
+                o->pktcnt = 0;
+                TAILQ_INSERT_TAIL(&croute->originList, o, next);
+            } else {
+                my_log(LOG_INFO, 0, "Have origin for route %s at %s, pktcnt %d",
+                    inetFmt(croute->group, s1),
+                    inetFmt(o->originAddr, s3),
+                    o->pktcnt);
             }
-
-            // set origin
-            croute->originAddrs[i] = originAddr;
-
-            // move it to the top
-            while (i > 0) {
-                uint32_t t = croute->originAddrs[i - 1];
-                croute->originAddrs[i - 1] = croute->originAddrs[i];
-                croute->originAddrs[i] = t;
-                i--;
-            }
         }
-        croute->upstrVif = upstrVif;
-
-        // Only update kernel table if there are listeners !
-        if(croute->vifBits > 0) {
-            result = internUpdateKernelRoute(croute, 1);
-        }
     }
     logRouteTable("Activate Route");
 
@@ -485,12 +461,7 @@ void ageActiveRoutes(void) {
 
     my_log(LOG_DEBUG, 0, "Aging routes in table.");
 
-    // Scan all routes...
-    for( croute = routing_table; croute != NULL; croute = nroute ) {
-
-        // Keep the next route (since current route may be removed)...
-        nroute = croute->nextroute;
-
+    RB_FOREACH_SAFE(croute, rtabletree, &routing_table, nroute) {
         // Run the aging round algorithm.
         if(croute->upstrState != ROUTESTATE_CHECK_LAST_MEMBER) {
             // Only age routes if Last member probe is not active...
@@ -602,7 +573,7 @@ static int removeRoute(struct RouteTable*  croute) {
     //BIT_ZERO(croute->vifBits);
 
     // Uninstall current route from kernel
-    if(!internUpdateKernelRoute(croute, 0)) {
+    if(!internUpdateKernelRoute(croute, 0, NULL)) {
         my_log(LOG_WARNING, 0, "The removal from Kernel failed.");
         result = 0;
     }
@@ -614,24 +585,8 @@ static int removeRoute(struct RouteTable*  croute) {
         sendJoinLeaveUpstream(croute, 0);
     }
 
-    // Update pointers...
-    if(croute->prevroute == NULL) {
-        // Topmost node...
-        if(croute->nextroute != NULL) {
-            croute->nextroute->prevroute = NULL;
-        }
-        routing_table = croute->nextroute;
+    rtable_remove(croute);
 
-    } else {
-        croute->prevroute->nextroute = croute->nextroute;
-        if(croute->nextroute != NULL) {
-            croute->nextroute->prevroute = croute->prevroute;
-        }
-    }
-    // Free the memory, and set the route to NULL...
-    free(croute);
-    croute = NULL;
-
     logRouteTable("Remove route");
 
     return result;
@@ -676,6 +631,37 @@ int internAgeRoute(struct RouteTable*  croute) {
         }
     }
 
+    {
+        struct Origin *o, *nxt;
+        struct sioc_sg_req sg_req;
+
+        sg_req.grp.s_addr = croute->group;
+        for (o = TAILQ_FIRST(&croute->originList); o; o = nxt) {
+            nxt = TAILQ_NEXT(o, next);
+            sg_req.src.s_addr = o->originAddr;
+            if (ioctl(MRouterFD, SIOCGETSGCNT, (char *)&sg_req) < 0) {
+                my_log(LOG_WARNING, errno, "%s (%s %s)",
+                        "age_table_entry: SIOCGETSGCNT failing for",
+                        inetFmt(o->originAddr, s1),
+                        inetFmt(croute->group, s2));
+                /* Make sure it gets deleted below */
+                sg_req.pktcnt = o->pktcnt;
+            }
+            my_log(LOG_DEBUG, 0, "Aging Origin %s Dst %s PktCnt %d -> %d",
+                    inetFmt(o->originAddr, s1), inetFmt(croute->group, s2),
+            o->pktcnt, sg_req.pktcnt);
+            if (sg_req.pktcnt == o->pktcnt) {
+                /* no traffic, remove from kernel cache */
+                internUpdateKernelRoute(croute, 0, o);
+                TAILQ_REMOVE(&croute->originList, o, next);
+                free(o);
+            } else {
+                o->pktcnt = sg_req.pktcnt;
+            }
+        }
+    }
+
+
     // If the aging counter has reached zero, its time for updating...
     if(croute->ageValue == 0) {
         // Check for activity in the aging process,
@@ -685,7 +671,7 @@ int internAgeRoute(struct RouteTable*  croute) {
                          inetFmt(croute->group,s1));
 
             // Just update the routing settings in kernel...
-            internUpdateKernelRoute(croute, 1);
+            internUpdateKernelRoute(croute, 1, NULL);
 
             // We append the activity counter to the age, and continue...
             croute->ageValue = croute->ageActivity;
@@ -711,39 +697,61 @@ int internAgeRoute(struct RouteTable*  croute) {
 /**
 *   Updates the Kernel routing table. If activate is 1, the route
 *   is (re-)activated. If activate is false, the route is removed.
+*   if 'origin' is given, only the route with 'origin' will be
+*   updated, otherwise all MFC routes for the group will updated.
 */
-int internUpdateKernelRoute(struct RouteTable *route, int activate) {
+int internUpdateKernelRoute(struct RouteTable *route, int activate, struct Origin *origin) {
     struct   MRouteDesc mrDesc;
     struct   IfDesc     *Dp;
     unsigned            Ix;
-    int i;
+    struct Origin *o;
 
-    for (i = 0; i < MAX_ORIGINS; i++) {
-        if (route->originAddrs[i] == 0 || route->upstrVif == -1) {
+    if (TAILQ_EMPTY(&route->originList)) {
+        my_log(LOG_NOTICE, 0, "Route is not active. No kernel updates done.");
+        return 1;
+    }
+
+    TAILQ_FOREACH(o, &route->originList, next) {
+        if (origin && origin != o)
             continue;
-        }
 
         // Build route descriptor from table entry...
         // Set the source address and group address...
         mrDesc.McAdr.s_addr     = route->group;
-        mrDesc.OriginAdr.s_addr = route->originAddrs[i];
+        mrDesc.OriginAdr.s_addr = o->originAddr;
 
         // clear output interfaces
         memset( mrDesc.TtlVc, 0, sizeof( mrDesc.TtlVc ) );
 
-        my_log(LOG_DEBUG, 0, "Vif bits : 0x%08x", route->vifBits);
+        my_log(LOG_DEBUG, 0, "Origin %s Vif bits : 0x%08x", inetFmt(o->originAddr, s1), route->vifBits);
 
         mrDesc.InVif = route->upstrVif;
 
         // Set the TTL's for the route descriptor...
-        for ( Ix = 0; (Dp = getIfByIx(Ix)); Ix++ ) {
-            if(Dp->state == IF_STATE_UPSTREAM) {
-                continue;
+        for ( Ix = 0; (Dp = getIfByIx( Ix )); Ix++ ) {
+            if (o->flood >= 0) {
+                if(Ix == o->flood) {
+                    my_log(LOG_DEBUG, 0, "Identified Input VIF #%d as DOWNSTREAM.", Dp->index);
+                    mrDesc.InVif = Dp->index;
+                }
+                else if(Dp->state == IF_STATE_UPSTREAM) {
+                    my_log(LOG_DEBUG, 0, "Setting TTL for UPSTREAM Vif %d to %d", Dp->index, Dp->threshold);
+                    mrDesc.TtlVc[ Dp->index ] = Dp->threshold;
+                }
+                else if(BIT_TST(route->vifBits, Dp->index)) {
+                    my_log(LOG_DEBUG, 0, "Setting TTL for DOWNSTREAM Vif %d to %d", Dp->index, Dp->threshold);
+                    mrDesc.TtlVc[ Dp->index ] = Dp->threshold;
+                }
+                } else {
+                if(Dp->state == IF_STATE_UPSTREAM) {
+                    my_log(LOG_DEBUG, 0, "Identified VIF #%d as upstream.", Dp->index);
+                    mrDesc.InVif = Dp->index;
+                }
+                else if(BIT_TST(route->vifBits, Dp->index)) {
+                    my_log(LOG_DEBUG, 0, "Setting TTL for Vif %d to %d", Dp->index, Dp->threshold);
+                    mrDesc.TtlVc[ Dp->index ] = Dp->threshold;
+                }
             }
-            else if(BIT_TST(route->vifBits, Dp->index)) {
-                my_log(LOG_DEBUG, 0, "Setting TTL for Vif %d to %d", Dp->index, Dp->threshold);
-                mrDesc.TtlVc[ Dp->index ] = Dp->threshold;
-            }
         }
 
         // Do the actual Kernel route update...
@@ -765,7 +773,7 @@ int internUpdateKernelRoute(struct RouteTable *route, 
 */
 void logRouteTable(const char *header) {
         struct Config       *conf = getCommonConfig();
-        struct RouteTable   *croute = routing_table;
+        struct RouteTable   *croute = RB_ROOT(&routing_table);
         unsigned            rcount = 0;
 
         my_log(LOG_DEBUG, 0, "");
@@ -774,30 +782,22 @@ void logRouteTable(const char *header) {
         if(croute==NULL) {
             my_log(LOG_DEBUG, 0, "No routes in table...");
         } else {
-            do {
-                char st = 'I';
-                char src[MAX_ORIGINS * 30 + 1];
-                src[0] = '\0';
-                int i;
-
-                for (i = 0; i < MAX_ORIGINS; i++) {
-                    if (croute->originAddrs[i] == 0) {
-                        continue;
-                    }
-                    st = 'A';
-                    sprintf(src + strlen(src), "Src%d: %s, ", i, inetFmt(croute->originAddrs[i], s1));
-                }
-
-                my_log(LOG_DEBUG, 0, "#%d: %sDst: %s, Age:%d, St: %c, OutVifs: 0x%08x, dHosts: %s",
-                    rcount, src, inetFmt(croute->group, s2),
-                    croute->ageValue, st,
+            RB_FOREACH(croute, rtabletree, &routing_table) {
+                my_log(LOG_DEBUG, 0, "#%d: Dst: %s, Age:%d, St: %s, OutVifs: 0x%08x, dHosts: %s",
+                    rcount, inetFmt(croute->group, s2),
+                    croute->ageValue,
+                    (TAILQ_EMPTY(&croute->originList) ? "I":"A"),
                     croute->vifBits,
                     !conf->fastUpstreamLeave ? "not tracked" : testNoDownstreamHost(conf, croute) ? "no" : "yes");
-
-                croute = croute->nextroute;
-
+                {
+                    struct Origin *o;
+                    TAILQ_FOREACH(o, &croute->originList, next) {
+                        my_log(LOG_DEBUG, 0, "#%d: Origin: %s floodIf %d pktcnt %d",
+                            rcount, inetFmt(o->originAddr, s1), o->flood, o->pktcnt);
+                    }
+                }
                 rcount++;
-            } while ( croute != NULL );
+            }
         }
 
         my_log(LOG_DEBUG, 0, "-----------------------------------------------------");
