
/*
 * pthread_create overload function
 * Copyleft Michael Meier 2008 - released under GPL v2.
 * This will pin every thread that is created except the first one (because
 * that seems to be the "master" in OpenMP binaries created by Intel
 * Compilers).
 * It will automatically get its current CPUset from the system and use only
 * these CPUs in a round-robin way, so you will most likely want to combine
 * this with "taskset".
 *
 * Compile with something like:
 *   gcc -Wall -O2 -D_GNU_SOURCE -o ptoverride.so -shared -fPIC -ldl -lpthread ./pthread-overload.c
 * Use with something like:
 *   LD_PRELOAD=./ptoverride.so OMP_NUM_THREADS=2 taskset -c 0,2 ./youropenmpbinary
 *
 * If libpthread.so cannot be found, you can specify its location at
 * compile-time by adding the following parameter to the compile command:
 *   -DLIBPTHREADLOCATION=/where/is/libpthread.so
 *
 * Advanced users can alter the pinning behaviour to adapt it to different
 * OpenMP variants or hybrid programs. Four environment variables are
 * evaluated by the pinner for that purpose:
 *  PINOMP_MASK    is a decimal or hex (with leading 0x) number interpreted
 *                 as a bitmask. Threads that have their corresponding bit in
 *                 the mask set will not be pinned.
 *  PINOMP_SKIP    is a decimal number, or several of them separated by a
 *                 single character such as a comma. The threads with those
 *                 numbers will not be pinned.
 *  PINOMP_CPUS    the CPUs to use, in the right order, separated by commas
 *  PINOMP_VERBOSE is a decimal that sets the verbosity of debug output. defaults to 2.
 */

#include <stdio.h>
#include <stdlib.h>
#include <dlfcn.h>
#include <sched.h>
#include <sys/types.h>
#include <dirent.h>
#include <unistd.h>
#include <string.h>

/*
 * Hand-written prototype for pthread_setaffinity_np(); <pthread.h> is not
 * included by this file (presumably because its pthread_create() prototype
 * would clash with the void*-based override defined below).
 * The size parameter is a size_t and the CPU set is const in the real
 * function, so the declaration has to say the same to stay compatible.
 * NOTE(review): "unsigned long" assumes pthread_t is an unsigned long, as it
 * is on Linux/glibc -- confirm before porting elsewhere.
 */
extern int pthread_setaffinity_np(unsigned long th, size_t cpusetsize, const cpu_set_t * cpuset);

/*
 * str(x) turns x into a string literal AFTER macro-expanding it. The extra
 * level of indirection is required: the operand of '#' is never expanded, so
 * a single-level macro would turn str(LIBPTHREADLOCATION) into the useless
 * string "LIBPTHREADLOCATION" instead of the path given with -D.
 * Caveat: every token of the path is macro-expanded, so path components that
 * happen to be predefined macros (e.g. "linux" or "unix" in GNU dialects)
 * get replaced as well.
 */
#define str_literal_(x) #x
#define str(x) str_literal_(x)

/*
 * Candidate locations of the real libpthread, tried in order by the
 * pthread_create() override. The list is terminated by NULL.
 */
static const char * sosearchpaths[] = {
#ifdef LIBPTHREADLOCATION
	str(LIBPTHREADLOCATION),
#endif
	"/lib64/tls/libpthread.so.0",	/* sles9 x86_64 */
	"libpthread.so.0",		/* Ubuntu and other proper distributions */
	NULL
};
/* Verbosity of the debug output; -1 means PINOMP_VERBOSE was not read yet. */
static int verblevel = -1;

/*
 * Returns the skip-mask bit that belongs to thread number idx, or 0 if idx
 * does not fit into an unsigned long. Shifting by the width of the type or
 * more is undefined behaviour, so such threads are simply never skipped.
 */
static unsigned long ptw_thread_bit(unsigned long idx)
{
  /* POSIX guarantees 8-bit bytes, which spares us a <limits.h> dependency. */
  if (idx >= 8 * sizeof(unsigned long)) {
    return 0UL;
  }
  return 1UL << idx;
}

/*
 * Parses a PINOMP_SKIP style list of decimal thread numbers and returns the
 * corresponding skip-mask bits. Anything that is not a number (separators,
 * stray characters) is stepped over one character at a time.
 */
static unsigned long ptw_parse_skip_list(const char *list)
{
  unsigned long bits = 0UL;
  const char *cursor = list;

  while (*cursor != '\0') {
    char *end;
    unsigned long idx = strtoul(cursor, &end, 10);
    if (end == cursor) {
      cursor++;               /* no digits here: separator or junk */
    } else {
      bits |= ptw_thread_bit(idx);
      cursor = end;
    }
  }
  return bits;
}

/*
 * Reads the next CPU number from the PINOMP_CPUS copy at *cursor, advances
 * *cursor past the number and one separator character, and rewinds it to
 * the start of the list once the end has been reached (round-robin).
 */
static int ptw_next_explicit_cpu(char *list, char **cursor)
{
  int cpu = (int)strtoul(*cursor, cursor, 10);

  if (**cursor != '\0') {
    (*cursor)++;              /* step over the separator */
  }
  if (**cursor == '\0') {
    *cursor = list;           /* wrap around */
  }
  return cpu;
}

/*
 * Returns the first CPU after 'last' (wrapping at CPU_SETSIZE) that is set
 * in 'allowed'. If no other CPU is set, 'last' itself is returned.
 */
static int ptw_next_allowed_cpu(const cpu_set_t *allowed, int last)
{
  int cpu = (last + 1) % CPU_SETSIZE;

  while ((cpu != last) && !CPU_ISSET(cpu, allowed)) {
    cpu = (cpu + 1) % CPU_SETSIZE;
  }
  return cpu;
}

/*
 * Opens the first loadable entry of sosearchpaths, starting at *index and
 * leaving *index on the entry that worked. Returns NULL when the list is
 * exhausted. The NULL terminator itself is never passed to dlopen():
 * dlopen(NULL) yields a handle for the main program, through which
 * "pthread_create" would resolve to this very wrapper and recurse forever.
 */
static void *ptw_open_libpthread(int *index)
{
  while (sosearchpaths[*index] != NULL) {
    void *lib = dlopen(sosearchpaths[*index], RTLD_LAZY);
    if (lib != NULL) {
      return lib;
    }
    (*index)++;
  }
  return NULL;
}

/*
 * Replacement for pthread_create(), meant to be injected via LD_PRELOAD.
 *
 * On the very first call the configuration is read from the environment
 * (PINOMP_VERBOSE, PINOMP_CPUS, PINOMP_MASK, PINOMP_SKIP), the affinity mask
 * of the process is queried, and the calling process (thread number 0) is
 * pinned unless bit 0 of the skip mask is set. Every call then forwards to
 * the real pthread_create() found through sosearchpaths and, if the thread
 * was created, pins it to the next CPU unless its bit in the skip mask is
 * set. CPUs are taken round-robin from PINOMP_CPUS if given, otherwise from
 * the affinity mask.
 *
 *   thread         receives the new thread id; it is read back as an
 *                  unsigned long after a successful creation
 *   attr, start_routine, arg
 *                  passed unchanged to the real pthread_create()
 *
 * Returns the result of the real pthread_create(), -1 if no libpthread could
 * be loaded, or -2 if "pthread_create" could not be resolved in it.
 *
 * NOTE: all bookkeeping lives in static variables without any locking, so
 * concurrent calls from several threads race on it.
 */
int pthread_create(void * thread, void * attr, void * (*start_routine)(void *), void * arg)
{
  void *handle;
  char *error;
  int (*rptc) (void *, void *, void * (*start_routine)(void *), void *);
  int ret;
  static int reallpthrindex = 0;
  static int npinned = 0;
  static cpu_set_t mask;
  static int lastpin = 0;
  static pid_t mainpid;
  static unsigned long pinningskipmask = 0;
  static int useexplicitcpus = 0;
  static char * cpustring =  NULL;
  static char * curcpustr = NULL;

  if (verblevel == -1) {
    const char *verbenv = getenv("PINOMP_VERBOSE");
    if (verbenv == NULL) {
      verblevel = 2;
    } else {
      verblevel = (int)strtol(verbenv, NULL, 10);
      if (verblevel < 0) { verblevel = 0; }
    }
  }

  if (npinned == 0) { /* first call: read the configuration, handle thread 0 */
    const char *cpusenv = getenv("PINOMP_CPUS");
    const char *maskenv = getenv("PINOMP_MASK");
    const char *skipenv = getenv("PINOMP_SKIP");

    /* An empty list is treated like an unset variable. */
    if ((cpusenv != NULL) && (*cpusenv != '\0')) {
      /* Private copy: the environment may be modified later on. */
      cpustring = strdup(cpusenv);
      if (cpustring == NULL) {
        printf("[pthread wrapper] WARNING: could not copy PINOMP_CPUS, using the affinity mask instead.\n");
      } else {
        curcpustr = cpustring;
        useexplicitcpus = 1;
      }
    }

    if ((maskenv == NULL) && (skipenv == NULL)) {
      pinningskipmask = 2; /* Default: skip the first created thread */
    } else {
      if (maskenv != NULL) {
        /* Base 0: decimal, or hex with a leading 0x, as documented. */
        pinningskipmask = strtoul(maskenv, NULL, 0);
      }
      if (skipenv != NULL) {
        pinningskipmask |= ptw_parse_skip_list(skipenv);
      }
    }
    if (verblevel > 1) {
      printf("[pthread wrapper] Pinning Skip Mask: 0x%lx\n", pinningskipmask);
    }

    npinned++; /* the main thread counts as thread number 0 */
    CPU_ZERO(&mask);
    ret = sched_getaffinity(getpid(), sizeof(mask), &mask);
    if (ret) {
      printf("[pthread wrapper] WARNING: sched_get_affinity returned error code %d, cannot pin correctly.\n", ret);
    } else {
      int j;
      if (verblevel > 1) {
        printf("[pthread wrapper] Using CPUs: ");
      }
      /* lastpin ends up on the highest allowed CPU, so that the round-robin
       * search below starts with the lowest one. */
      for (j = 0; j < CPU_SETSIZE; j++) {
        if (CPU_ISSET(j, &mask)) {
          lastpin = j;
          if (verblevel > 1) {
            printf(" %d", j);
          }
        }
      }
      if (verblevel > 0) {
        printf("\n[pthread wrapper] ");
      }
      mainpid = getpid();
      if ((pinningskipmask & 1UL) != 0) { /* bit 0 belongs to the main thread */
        if (verblevel > 0) {
          printf("Main PID: %d -> SKIP!\n", mainpid);
        }
      } else {
        cpu_set_t mymask;
        int usecpu;
        if (useexplicitcpus) {
          usecpu = ptw_next_explicit_cpu(cpustring, &curcpustr);
        } else {
          usecpu = ptw_next_allowed_cpu(&mask, lastpin);
        }
        lastpin = usecpu;
        CPU_ZERO(&mymask);
        CPU_SET(usecpu, &mymask);
        if (verblevel > 0) {
          printf("Main PID: %d -> core %d - ", mainpid, usecpu);
        }
        if (sched_setaffinity(mainpid, sizeof(mymask), &mymask)) {
          perror("sched_setaffinity failed");
        } else {
          if (verblevel > 0) {
            printf("OK\n");
          }
        }
      }
    }
  }

  if (verblevel > 0) {
    printf("[pthread wrapper] ");
  }
  handle = ptw_open_libpthread(&reallpthrindex);
  if (!handle) {
    /* dlerror() is NULL if the list was already exhausted by an earlier call */
    const char *why = dlerror();
    printf("%s\n", (why != NULL) ? why : "unable to load the real libpthread");
    return -1;
  }
  if (verblevel > 1) {
    printf("[Notice: Using %s] ", sosearchpaths[reallpthrindex]);
  }
  dlerror();    /* Clear any existing error */
  *(void **) (&rptc) = dlsym(handle, "pthread_create");
  error = dlerror();
  if ((error != NULL) || (rptc == NULL)) {
    printf("%s\n", (error != NULL) ? error : "pthread_create not found in libpthread");
    dlclose(handle);  /* do not leak the handle on the error path */
    return -2;
  }

  ret = (*rptc)(thread, attr, start_routine, arg);
  if (ret == 0) {  /* Successful thread creation: pin the new thread. */
    unsigned long tid = *(unsigned long *)thread;
    if ((pinningskipmask & ptw_thread_bit((unsigned long)npinned)) != 0) {
      if (verblevel > 0) {
        printf("threadid 0x%lx -> SKIP!\n", tid);
      }
    } else {
      cpu_set_t mymask;
      int usecpu;
      int rc;
      if (useexplicitcpus) {
        usecpu = ptw_next_explicit_cpu(cpustring, &curcpustr);
      } else {
        usecpu = ptw_next_allowed_cpu(&mask, lastpin);
      }
      lastpin = usecpu;
      CPU_ZERO(&mymask);
      CPU_SET(usecpu, &mymask);
      if (verblevel > 0) {
        printf("threadid 0x%lx -> core %d - ", tid, usecpu);
      }
      /* pthread functions return the error number instead of setting errno,
       * so perror() would print an unrelated message here. */
      rc = pthread_setaffinity_np(tid, sizeof(mymask), &mymask);
      if (rc != 0) {
        fprintf(stderr, "pthread_setaffinity_np failed: %s\n", strerror(rc));
      } else {
        if (verblevel > 0) {
          printf("OK\n");
        }
      }
    }
  }
  npinned++;
  dlclose(handle);
  return ret;
}


