#!/usr/bin/env python

# Copyright (C) 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import atexit
import hashlib
import os
import shutil
import signal
import subprocess
import sys
import tempfile
import time
import urllib

# SHA-1 digests of the trace_to_text binaries, keyed by the platform names
# used in main() ('linux' / 'mac'). The digest is also part of the file name
# that load_trace_to_text() downloads.
TRACE_TO_TEXT_SHAS = {
    'linux': '74cb40d9413d7ad756e327e09e46a8adde90db92',
    'mac': '47045bc6abd4f9113969211f37fdd2259b0c1cc3',
}
# Directory in which the downloaded trace_to_text binary is stored.
TRACE_TO_TEXT_PATH = tempfile.gettempdir()
TRACE_TO_TEXT_BASE_URL = 'https://storage.googleapis.com/perfetto/'

# Sink for subprocess output. It is handed to children as stdout/stderr, so it
# has to be opened for writing; a read-only handle makes every write in the
# child fail.
NULL = open(os.devnull, 'w')
# Keyword arguments for subprocess calls whose output should be discarded.
NOOUT = {
    'stdout': NULL,
    'stderr': NULL,
}


def check_hash(file_name, sha_value):
  """Checks whether the SHA-1 digest of a file matches an expected value.

  Args:
    file_name: path of the file to hash.
    sha_value: expected digest, in the hexadecimal form produced by
      hashlib's hexdigest().

  Returns:
    True if the file's SHA-1 hex digest equals sha_value, False otherwise.
  """
  file_hash = hashlib.sha1()
  with open(file_name, 'rb') as fd:
    # Feed the file to the hash in fixed-size chunks so that memory use does
    # not grow with the size of the file.
    for chunk in iter(lambda: fd.read(64 * 1024), b''):
      file_hash.update(chunk)
  return file_hash.hexdigest() == sha_value


def load_trace_to_text(platform):
  """Returns the path of a hash-verified trace_to_text binary.

  A copy already present in TRACE_TO_TEXT_PATH is reused if its SHA-1 matches
  the pinned digest; otherwise the binary is downloaded from
  TRACE_TO_TEXT_BASE_URL, verified and made executable.

  Args:
    platform: key into TRACE_TO_TEXT_SHAS ('linux' or 'mac').

  Returns:
    Path of the local trace_to_text binary.

  Raises:
    KeyError: if platform has no entry in TRACE_TO_TEXT_SHAS.
    ValueError: if the downloaded file does not match the pinned digest.
  """
  # urlretrieve lives in urllib.request on Python 3 and directly in urllib on
  # Python 2; support both.
  try:
    from urllib.request import urlretrieve
  except ImportError:
    from urllib import urlretrieve

  sha_value = TRACE_TO_TEXT_SHAS[platform]
  file_name = 'trace_to_text-' + platform + '-' + sha_value
  local_file = os.path.join(TRACE_TO_TEXT_PATH, file_name)

  if os.path.exists(local_file):
    if check_hash(local_file, sha_value):
      return local_file
    # Stale or corrupted copy: discard it and download again.
    os.remove(local_file)

  url = TRACE_TO_TEXT_BASE_URL + file_name
  urlretrieve(url, local_file)
  if not check_hash(local_file, sha_value):
    # Never leave an unverified binary behind.
    os.remove(local_file)
    raise ValueError("Invalid signature.")
  os.chmod(local_file, 0o755)
  return local_file


# Data source section that main() appends to the trace config unless
# --no-versions is given.
PACKAGES_LIST_CFG = '''data_sources {
  config {
    name: "android.packages_list"
  }
}
'''

# Indentation that main() puts in front of the pid / process_cmdline lines it
# generates for the heapprofd_config section.
CFG_IDENT = '      '
# Trace config template. It is filled in with str.format(), so literal braces
# are doubled and {shmem_size}, {interval}, {target_cfg},
# {continuous_dump_cfg} and {duration} are placeholders.
CFG = '''buffers {{
  size_kb: 32768
}}

data_sources {{
  config {{
    name: "android.heapprofd"
    heapprofd_config {{

      shmem_size_bytes: {shmem_size}
      sampling_interval_bytes: {interval}
{target_cfg}
{continuous_dump_cfg}
    }}
  }}
}}

duration_ms: {duration}
write_into_file: true
flush_timeout_ms: 30000
'''

# Optional section substituted for {continuous_dump_cfg} when a non-zero
# --continuous-dump interval is given; also a str.format() template.
CONTINUOUS_DUMP = """
      continuous_dump_config {{
        dump_phase_ms: 0
        dump_interval_ms: {dump_interval}
      }}
"""

# Shell command template that main() runs on the device through adb. main()
# expects the command's output to be the PID of the started perfetto process.
PERFETTO_CMD = ('CFG=\'{cfg}\'; echo ${{CFG}} | '
                'perfetto --txt -c - -o '
                '/data/misc/perfetto-traces/profile-{user} -d')
# Set to True by sigint_handler(); polled by the wait loop in main().
IS_INTERRUPTED = False


def sigint_handler(sig, frame):
  """Signal handler that records that an interrupt was requested.

  It only flips the module-level IS_INTERRUPTED flag; the code polling that
  flag decides how to react.

  Args:
    sig: signal number (unused).
    frame: interrupted stack frame (unused).
  """
  global IS_INTERRUPTED
  del sig, frame  # Only present to satisfy the signal handler signature.
  IS_INTERRUPTED = True


def main(argv):
  """Records a heap profile from an adb-connected device using heapprofd.

  Builds a perfetto trace config from the command line flags, starts perfetto
  on the device through adb, waits until the perfetto process is gone or
  Ctrl+C is pressed, pulls the resulting trace and converts it with
  trace_to_text.

  Args:
    argv: full argument vector with the program name at index 0 (sys.argv).

  Returns:
    0 on success (including --print-config), 1 on failure.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "-i",
      "--interval",
      help="Sampling interval. "
      "Default 4096 (4KiB)",
      type=int,
      default=4096)
  parser.add_argument(
      "-d",
      "--duration",
      help="Duration of profile (ms). "
      "Default 7 days.",
      type=int,
      default=604800000)
  parser.add_argument(
      "--no-start", help="Do not start heapprofd.", action='store_true')
  parser.add_argument(
      "-p",
      "--pid",
      help="Comma-separated list of PIDs to "
      "profile.",
      metavar="PIDS")
  parser.add_argument(
      "-n",
      "--name",
      help="Comma-separated list of process "
      "names to profile.",
      metavar="NAMES")
  parser.add_argument(
      "-c",
      "--continuous-dump",
      help="Dump interval in ms. 0 to disable continuous dump.",
      type=int,
      default=0)
  parser.add_argument(
      "--disable-selinux",
      action="store_true",
      help="Disable SELinux enforcement for duration of "
      "profile.")
  parser.add_argument(
      "--no-versions",
      action="store_true",
      help="Do not get version information about APKs.")
  parser.add_argument(
      "--no-running",
      action="store_true",
      help="Do not target already running processes.")
  parser.add_argument(
      "--no-startup",
      action="store_true",
      help="Do not target processes that start during "
      "the profile.")
  parser.add_argument(
      "--shmem-size",
      help="Size of buffer between client and "
      "heapprofd. Default 8MiB. Needs to be a power of two "
      "multiple of 4096, at least 8192.",
      type=int,
      default=8 * 1048576)
  parser.add_argument(
      "--block-client",
      help="When buffer is full, block the "
      "client to wait for buffer space. Use with caution as "
      "this can significantly slow down the client. "
      "This is the default",
      action="store_true")
  parser.add_argument(
      "--block-client-timeout",
      help="If --block-client is given, do not block any allocation for "
      "longer than this timeout (us).",
      type=int)
  parser.add_argument(
      "--no-block-client",
      help="When buffer is full, stop the "
      "profile early.",
      action="store_true")
  parser.add_argument(
      "--idle-allocations",
      help="Keep track of how many "
      "bytes were unused since the last dump, per "
      "callstack",
      action="store_true")
  parser.add_argument(
      "--dump-at-max",
      help="Dump the maximum memory usage "
      "rather than at the time of the dump.",
      action="store_true")
  parser.add_argument(
      "--simpleperf",
      action="store_true",
      help="Get simpleperf profile of heapprofd. This is "
      "only for heapprofd development.")
  parser.add_argument(
      "--trace-to-text-binary",
      help="Path to local trace to text. For debugging.")
  parser.add_argument(
      "--print-config",
      action="store_true",
      help="Print config instead of running. For debugging.")

  # Parse the vector we were handed rather than implicitly reading sys.argv.
  args = parser.parse_args(argv[1:])

  # Collect every validation error before bailing out, so that the user sees
  # all problems at once.
  fail = False
  if args.block_client and args.no_block_client:
    print(
        "FATAL: Both block-client and no-block-client given.", file=sys.stderr)
    fail = True
  if args.pid is None and args.name is None:
    print("FATAL: Neither PID nor NAME given.", file=sys.stderr)
    fail = True
  if args.duration is None:
    print("FATAL: No duration given.", file=sys.stderr)
    fail = True
  if args.interval is None:
    print("FATAL: No interval given.", file=sys.stderr)
    fail = True
  if args.shmem_size % 4096:
    print("FATAL: shmem-size is not a multiple of 4096.", file=sys.stderr)
    fail = True
  if args.shmem_size < 8192:
    print("FATAL: shmem-size is less than 8192.", file=sys.stderr)
    fail = True
  if args.shmem_size & (args.shmem_size - 1):
    print("FATAL: shmem-size is not a power of two.", file=sys.stderr)
    fail = True

  # Lines that are substituted into the heapprofd_config section of CFG.
  target_cfg = ""
  if not args.no_block_client:
    target_cfg += "block_client: true\n"
  if args.block_client_timeout:
    target_cfg += "block_client_timeout_us: %s\n" % args.block_client_timeout
  if args.idle_allocations:
    target_cfg += "idle_allocations: true\n"
  if args.no_startup:
    target_cfg += "no_startup: true\n"
  if args.no_running:
    target_cfg += "no_running: true\n"
  if args.dump_at_max:
    target_cfg += "dump_at_max: true\n"
  if args.pid:
    for pid in args.pid.split(','):
      try:
        pid = int(pid)
      except ValueError:
        print("FATAL: invalid PID %s" % pid, file=sys.stderr)
        fail = True
        # Do not put the malformed value into the config.
        continue
      target_cfg += '{}pid: {}\n'.format(CFG_IDENT, pid)
  if args.name:
    for name in args.name.split(','):
      target_cfg += '{}process_cmdline: "{}"\n'.format(CFG_IDENT, name)

  if fail:
    parser.print_help()
    return 1

  trace_to_text_binary = args.trace_to_text_binary
  if trace_to_text_binary is None:
    platform = None
    if sys.platform.startswith('linux'):
      platform = 'linux'
    elif sys.platform.startswith('darwin'):
      platform = 'mac'
    else:
      print("Invalid platform: {}".format(sys.platform), file=sys.stderr)
      return 1

    trace_to_text_binary = load_trace_to_text(platform)

  continuous_dump_cfg = ""
  if args.continuous_dump:
    continuous_dump_cfg = CONTINUOUS_DUMP.format(
        dump_interval=args.continuous_dump)
  cfg = CFG.format(
      interval=args.interval,
      duration=args.duration,
      target_cfg=target_cfg,
      continuous_dump_cfg=continuous_dump_cfg,
      shmem_size=args.shmem_size)
  if not args.no_versions:
    cfg += PACKAGES_LIST_CFG

  if args.print_config:
    print(cfg)
    return 0

  # All command output below is requested as text (universal_newlines=True):
  # on Python 3 check_output otherwise returns bytes, which neither compare
  # equal to str literals nor format cleanly into command lines.
  if args.disable_selinux:
    # Strip the trailing newline so that the restore command stays on a
    # single line.
    enforcing = subprocess.check_output(
        ['adb', 'shell', 'getenforce'], universal_newlines=True).strip()
    # Restore the previous enforcement mode when the script exits.
    atexit.register(
        subprocess.check_call,
        ['adb', 'shell', 'su root setenforce %s' % enforcing])
    subprocess.check_call(['adb', 'shell', 'su root setenforce 0'])

  if not args.no_start:
    heapprofd_prop = subprocess.check_output(
        ['adb', 'shell', 'getprop persist.heapprofd.enable'],
        universal_newlines=True)
    # Only touch (and later reset) the property if it was not already set.
    if heapprofd_prop.strip() != '1':
      subprocess.check_call(
          ['adb', 'shell', 'setprop persist.heapprofd.enable 1'])
      atexit.register(subprocess.check_call,
                      ['adb', 'shell', 'setprop persist.heapprofd.enable 0'])

  user = subprocess.check_output(
      ['adb', 'shell', 'whoami'], universal_newlines=True).strip()

  if args.simpleperf:
    subprocess.check_call([
        'adb', 'shell', 'mkdir -p /data/local/tmp/heapprofd_profile && '
        'cd /data/local/tmp/heapprofd_profile &&'
        '(nohup simpleperf record -g -p $(pidof heapprofd) 2>&1 &) '
        '> /dev/null'
    ])

  perfetto_pid = subprocess.check_output(
      ['adb', 'exec-out',
       PERFETTO_CMD.format(cfg=cfg, user=user)],
      universal_newlines=True).strip()
  # Anything other than a bare PID is treated as an error message.
  try:
    int(perfetto_pid)
  except ValueError:
    print("Failed to invoke perfetto: {}".format(perfetto_pid), file=sys.stderr)
    return 1

  old_handler = signal.signal(signal.SIGINT, sigint_handler)
  print("Profiling active. Press Ctrl+C to terminate.")
  print("You may disconnect your device.")
  exists = True
  device_connected = True
  # Keep polling while the device is unreachable: a failed adb call says
  # nothing about whether perfetto is still running.
  while not device_connected or (exists and not IS_INTERRUPTED):
    exists = subprocess.call(
        ['adb', 'shell', '[ -d /proc/{} ]'.format(perfetto_pid)], **NOOUT) == 0
    device_connected = subprocess.call(['adb', 'shell', 'true'], **NOOUT) == 0
    time.sleep(1)
  signal.signal(signal.SIGINT, old_handler)
  if IS_INTERRUPTED:
    # Not check_call because it could have exited in the meantime.
    subprocess.call(['adb', 'shell', 'kill', '-INT', perfetto_pid])
  if args.simpleperf:
    subprocess.check_call(['adb', 'shell', 'killall', '-INT', 'simpleperf'])
    print("Waiting for simpleperf to exit.")
    while subprocess.call(
        ['adb', 'shell', '[ -f /proc/$(pidof simpleperf)/exe ]'], **NOOUT) == 0:
      time.sleep(1)
    subprocess.check_call(
        ['adb', 'pull', '/data/local/tmp/heapprofd_profile', '/tmp'])
    print("Pulled simpleperf profile to /tmp/heapprofd_profile")

  # Wait for perfetto cmd to return.
  while exists:
    exists = subprocess.call(
        ['adb', 'shell', '[ -d /proc/{} ]'.format(perfetto_pid)]) == 0
    time.sleep(1)

  subprocess.check_call([
      'adb', 'pull', '/data/misc/perfetto-traces/profile-{}'.format(user),
      '/tmp/profile'
  ],
                        stdout=NULL)
  trace_to_text_output = subprocess.check_output(
      [trace_to_text_binary, 'profile', '/tmp/profile'],
      universal_newlines=True)
  # The output directory is taken from trace_to_text's output: the last word
  # containing 'heap_profile-'.
  profile_path = None
  for word in trace_to_text_output.split():
    if 'heap_profile-' in word:
      profile_path = word
  if profile_path is None:
    print("Could not find trace_to_text output path.", file=sys.stderr)
    return 1

  profile_files = os.listdir(profile_path)
  if not profile_files:
    print("No profiles generated", file=sys.stderr)
    print(
        "If this is unexpected, check "
        "https://docs.perfetto.dev/#/heapprofd?id=troubleshooting.",
        file=sys.stderr)
    return 1

  subprocess.check_call(
      ['gzip'] + [os.path.join(profile_path, x) for x in profile_files])

  symlink_path = os.path.join(
      os.path.dirname(profile_path), "heap_profile-latest")
  # lexists also catches a dangling symlink left behind by an earlier run.
  if os.path.lexists(symlink_path):
    os.unlink(symlink_path)
  os.symlink(profile_path, symlink_path)
  shutil.copyfile('/tmp/profile', os.path.join(profile_path, 'raw-trace'))

  print("Wrote profiles to {} (symlink {})".format(profile_path, symlink_path))
  print("These can be viewed using pprof. Googlers: head to pprof/ and "
        "upload them.")
  return 0


if __name__ == '__main__':
  # Use main()'s return value as the exit status of the process.
  exit_status = main(sys.argv)
  sys.exit(exit_status)
