/*
 * elfstrchange.c by truff <truff@projet7.org>
 * Change the value of a symbol name in the .strtab section
 *
 * Usage: elfstrchange elf_object sym_name sym_name_replaced
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <elf.h>

/*
 * Report the failure described by errno, prefixed with X, and terminate the
 * program with EXIT_FAILURE.
 *
 * The body is wrapped in do { } while (0) so that `FATAL (x);` expands to
 * exactly one statement and stays correct inside an unbraced if/else.
 */
#define FATAL(X) do { perror (X); exit (EXIT_FAILURE); } while (0)


/*
 * Forward declarations for every helper defined below main, so that no call
 * relies on an implicit function declaration.
 */
int ElfGetSectionByIndex (FILE *fd, Elf32_Ehdr *ehdr, Elf32_Half index,
                       Elf32_Shdr *shdr);

int ElfGetSectionByName (FILE *fd, Elf32_Ehdr *ehdr, char *section,
                       Elf32_Shdr *shdr);

int ElfGetSectionName (FILE *fd, Elf32_Word sh_name,
                       Elf32_Shdr *shstrtable, char *res, size_t len);

Elf32_Off ElfGetSymbolByName (FILE *fd, Elf32_Shdr *symtab,
                       Elf32_Shdr *strtab, char *name, Elf32_Sym *sym);

Elf32_Off ElfGetSymbolName (FILE *fd, Elf32_Word sym_name,
                       Elf32_Shdr *strtable, char *res, size_t len);


/*
 * Entry point: rename a symbol in place inside an ELF32 object.
 *
 * argv[1] is the object file, argv[2] the current symbol name and argv[3]
 * the replacement.  The .symtab and .strtab sections are located, the symbol
 * is looked up by name and the bytes of its .strtab entry are overwritten.
 *
 * The replacement must not be longer than the current name, otherwise the
 * write would run over the bytes that follow the entry in the string table.
 * The terminating NUL is written too, so a shorter replacement does not keep
 * the tail of the old name.
 *
 * Returns EXIT_SUCCESS on success; every failure path prints a diagnostic
 * and exits with EXIT_FAILURE.
 */
int main (int argc, char **argv)
{
  size_t oldlen, newlen;
  FILE *fd;
  Elf32_Ehdr hdr;
  Elf32_Shdr symtab, strtab;
  Elf32_Sym sym;
  Elf32_Off symoffset;

  if (argc < 4)
  {
    fprintf (stderr, "Usage: %s elf_object sym_name sym_name_replaced\n",
             (argc > 0) ? argv[0] : "elfstrchange");
    exit (EXIT_FAILURE);
  }

  oldlen = strlen (argv[2]);
  newlen = strlen (argv[3]);
  if (newlen > oldlen)
  {
    fprintf (stderr, "Replacement name must not be longer than %s\n",
             argv[2]);
    exit (EXIT_FAILURE);
  }

  /* Binary update mode: the file is patched byte for byte. */
  fd = fopen (argv[1], "r+b");
  if (fd == NULL)
    FATAL ("fopen");

  if (fread (&hdr, sizeof (Elf32_Ehdr), 1, fd) < 1)
    FATAL ("Elf header corrupted");

  /* Everything below parses Elf32_* records, so refuse anything else. */
  if (memcmp (hdr.e_ident, ELFMAG, SELFMAG) != 0
      || hdr.e_ident[EI_CLASS] != ELFCLASS32)
  {
    fprintf (stderr, "%s is not a 32-bit ELF object\n", argv[1]);
    exit (EXIT_FAILURE);
  }

  if (ElfGetSectionByName (fd, &hdr, ".symtab", &symtab) == -1)
  {
    fprintf (stderr, "Can't get .symtab section\n");
    exit (EXIT_FAILURE);
  }

  if (ElfGetSectionByName (fd, &hdr, ".strtab", &strtab) == -1)
  {
    fprintf (stderr, "Can't get .strtab section\n");
    exit (EXIT_FAILURE);
  }

  symoffset = ElfGetSymbolByName (fd, &symtab, &strtab, argv[2], &sym);
  if (symoffset == (Elf32_Off) -1)
  {
    fprintf (stderr, "Symbol %s not found\n", argv[2]);
    exit (EXIT_FAILURE);
  }

  printf ("[+] Symbol %s located at 0x%x\n", argv[2],
          (unsigned int) symoffset);

  if (fseek (fd, (long) symoffset, SEEK_SET) == -1)
    FATAL ("fseek");

  /* newlen + 1 also writes the terminating NUL of the replacement. */
  if (fwrite (argv[3], 1, newlen + 1, fd) < newlen + 1)
    FATAL ("fwrite");

  /* fclose flushes the buffered write, so its result matters here. */
  if (fclose (fd) == EOF)
    FATAL ("fclose");

  printf ("[+] .strtab entry overwriten with %s\n", argv[3]);

  return EXIT_SUCCESS;
}

/*
 * Look up a symbol by name in the symbol table.
 *
 * Walks the entries of `symtab`, reading each one into `*sym` and fetching
 * its name from `strtab`, until a name equal to `name` is found.
 *
 * Returns the file offset of the matching name inside the string table, or
 * (Elf32_Off) -1 when no symbol matches or the table reports an entry size
 * of zero.  On a match `*sym` holds the matching entry.  Seek and read
 * failures terminate the program through FATAL.
 */
Elf32_Off ElfGetSymbolByName (FILE *fd, Elf32_Shdr *symtab,
            Elf32_Shdr *strtab, char *name, Elf32_Sym *sym)
{
  Elf32_Word i;
  Elf32_Word count;
  char symname[255];
  Elf32_Off offset;

  /* A zero entry size would make the division below undefined. */
  if (symtab->sh_entsize == 0)
    return (Elf32_Off) -1;

  count = symtab->sh_size / symtab->sh_entsize;

  for (i = 0; i < count; i++)
  {
    if (fseek (fd, (long) (symtab->sh_offset + (i * symtab->sh_entsize)),
               SEEK_SET) == -1)
      FATAL ("fseek");

    if (fread (sym, sizeof (Elf32_Sym), 1, fd) < 1)
      FATAL ("Symtab corrupted");

    memset (symname, 0, sizeof (symname));
    offset = ElfGetSymbolName (fd, sym->st_name,
                        strtab, symname, sizeof (symname));

    /* Guarantee a terminator before handing the buffer to strcmp. */
    symname[sizeof (symname) - 1] = '\0';

    if (!strcmp (symname, name))
      return offset;
  }

  return (Elf32_Off) -1;
}


/*
 * Read the section header with the given index into `*shdr`.
 *
 * The header is located at e_shoff + index * e_shentsize.  The product is
 * computed in Elf32_Off: both factors are 16-bit values that would otherwise
 * be promoted to int, where the multiplication can overflow.
 *
 * Returns 0 on success.  A failing seek or a short read prints a diagnostic
 * with perror and exits with EXIT_FAILURE.
 */
int ElfGetSectionByIndex (FILE *fd, Elf32_Ehdr *ehdr, Elf32_Half index,
    Elf32_Shdr *shdr)
{
  Elf32_Off pos;

  pos = ehdr->e_shoff + (Elf32_Off) index * ehdr->e_shentsize;

  if (fseek (fd, (long) pos, SEEK_SET) == -1)
  {
    perror ("fseek");
    exit (EXIT_FAILURE);
  }

  if (fread (shdr, sizeof (Elf32_Shdr), 1, fd) < 1)
  {
    perror ("Sections header corrupted");
    exit (EXIT_FAILURE);
  }

  return 0;
}
  

/*
 * Find the section header whose name equals `section`.
 *
 * The section header string table is loaded first (index e_shstrndx), then
 * each of the e_shnum headers is read into `*shdr` and its name compared
 * with `section`.
 *
 * Returns 0 with `*shdr` holding the matching header, or -1 when no section
 * carries that name.  Headers are read through ElfGetSectionByIndex, which
 * terminates the program on seek or read failure.
 */
int ElfGetSectionByName (FILE *fd, Elf32_Ehdr *ehdr, char *section,
                         Elf32_Shdr *shdr)
{
  Elf32_Half i;
  char name[255];
  Elf32_Shdr shstrtable;

  /*
   * Get the section header string table
   */
  ElfGetSectionByIndex (fd, ehdr, ehdr->e_shstrndx, &shstrtable);

  for (i = 0; i < ehdr->e_shnum; i++)
  {
    ElfGetSectionByIndex (fd, ehdr, i, shdr);

    memset (name, 0, sizeof (name));
    ElfGetSectionName (fd, shdr->sh_name, &shstrtable,
                       name, sizeof (name));

    /* Guarantee a terminator before handing the buffer to strcmp. */
    name[sizeof (name) - 1] = '\0';

    if (!strcmp (name, section))
      return 0;
  }

  return -1;
}


/*
 * Copy a section name out of the section header string table.
 *
 * Seeks to sh_offset + sh_name and copies bytes into `res` until the
 * terminating NUL, end of file, or `len - 1` bytes, whichever comes first.
 * `res` is always NUL-terminated when `len` is non-zero and nothing is
 * written beyond res[len - 1]; names that do not fit are truncated.
 *
 * Returns 0.  A failing seek prints a diagnostic with perror and exits with
 * EXIT_FAILURE.
 */
int ElfGetSectionName (FILE *fd, Elf32_Word sh_name,
    Elf32_Shdr *shstrtable, char *res, size_t len)
{
  size_t i = 0;
  int c;

  if (fseek (fd, (long) (shstrtable->sh_offset + sh_name), SEEK_SET) == -1)
  {
    perror ("fseek");
    exit (EXIT_FAILURE);
  }

  if (len == 0)
    return 0;

  /* Keep fgetc's result in an int so EOF is distinguishable from a byte. */
  while (i + 1 < len)
  {
    c = fgetc (fd);
    if (c == EOF || c == '\0')
      break;
    res[i++] = (char) c;
  }
  res[i] = '\0';

  return 0;
}


/*
 * Copy a symbol name out of a string table and report where it lives.
 *
 * Seeks to sh_offset + sym_name and copies bytes into `res` until the
 * terminating NUL, end of file, or `len - 1` bytes, whichever comes first.
 * `res` is always NUL-terminated when `len` is non-zero and nothing is
 * written beyond res[len - 1]; names that do not fit are truncated.
 *
 * Returns the file offset of the name (sh_offset + sym_name).  A failing
 * seek prints a diagnostic with perror and exits with EXIT_FAILURE.
 */
Elf32_Off ElfGetSymbolName (FILE *fd, Elf32_Word sym_name,
    Elf32_Shdr *strtable, char *res, size_t len)
{
  Elf32_Off where = strtable->sh_offset + sym_name;
  size_t i = 0;
  int c;

  if (fseek (fd, (long) where, SEEK_SET) == -1)
  {
    perror ("fseek");
    exit (EXIT_FAILURE);
  }

  if (len > 0)
  {
    /* Keep fgetc's result in an int so EOF is distinguishable from a byte. */
    while (i + 1 < len)
    {
      c = fgetc (fd);
      if (c == EOF || c == '\0')
        break;
      res[i++] = (char) c;
    }
    res[i] = '\0';
  }

  return where;
}
/* EOF */
