#!/usr/bin/env python
# Copyright 2017 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Grabs before/after memory dumps using the devtools remote protocol.

To use it you first start Chrome with remote debugging enabled then run the
script, which starts a trace and takes a memory dump. Each time you press enter
it takes another memory dump; entering q (or pressing Ctrl-C) ends the trace
and saves it to a trace file. For example:

On OSX:
$ /Applications/Google\\ Chrome.app/Contents/MacOS/Google\\ Chrome \\
    --remote-debugging-port=9222 \\
    --memlog=all --memlog-sampling --memlog-stack-mode=pseudo
$ ./tracing/bin/memory_infra_remote_dump --port=9222
...
[Press enter to trigger a new dump, q to finish the trace]
...
/var/folders/18/gl6q632j20nc_tw5g9l03dhc007g45/T/trace_20s191.json: 835 KB

On Android:
$ ./build/android/adb_chrome_public_command_line \\
    --memlog=all --memlog-sampling --memlog-stack-mode=pseudo \\
    --enable-remote-debugging
$ ./build/android/adb_run_chrome_public
$ adb forward tcp:1234 localabstract:chrome_devtools_remote
$ ./third_party/catapult/tracing/bin/memory_infra_remote_dump --port=1234
...
[Press enter to trigger a new dump, q to finish the trace]
...
/var/folders/18/gl6q632j20nc_tw5g9l03dhc007g45/T/trace_20s191.json: 835 KB
"""

import argparse
import json
import os
import requests
import sys
import tempfile
import time

# websocket-client is a third-party package that may not be installed; exit
# with an actionable hint instead of a raw ImportError traceback.
try:
  import websocket
except ImportError:
  sys.stdout.write('Please run: pip install --user websocket-client\n')
  sys.exit(1)


class TracingDevtoolsClient(object):
  """Small synchronous devtools remote-protocol client for memory dumps.

  Commands are sent one at a time over a websocket; send() blocks until the
  response carrying the same id arrives. Event notifications (messages without
  an id) that show up in the meantime are buffered and returned by recv().
  """

  def __init__(self, host, port):
    """Connects to the websocket advertised by http://host:port/json/version.

    Args:
      host: Host name of the remote debugging endpoint.
      port: Port of the remote debugging endpoint.
    """
    r = requests.get('http://%s:%s/json/version' % (host, port))
    url = r.json()['webSocketDebuggerUrl']
    print('Connecting to ' + url)
    self.ws = websocket.create_connection(url)
    self.cmd = 0  # Id of the most recently sent command.
    # Events received while send() was waiting for a command response, in
    # arrival order. recv() drains this list before touching the socket.
    self._events = []

  def _read_message(self):
    """Blocks for the next websocket message and returns it parsed from JSON."""
    return json.loads(self.ws.recv())

  def send(self, method, params=None):
    """Sends a command and returns the 'result' dict of its response.

    Args:
      method: Protocol method name, e.g. 'Tracing.start'.
      params: Optional dict of method parameters; defaults to no parameters.

    Returns:
      The 'result' member of the response, or {} if the response has none.

    Raises:
      RuntimeError: if the response carries an 'error' member, or a response
          for a different command id is received.
    """
    self.cmd += 1
    request = {
        'id': self.cmd,
        'method': method,
        'params': {} if params is None else params,
    }
    self.ws.send(json.dumps(request))

    while True:
      resp = self._read_message()
      if 'id' not in resp:
        # An event notification, not a command response. Keep it so that a
        # later recv() still sees it instead of failing on it here.
        self._events.append(resp)
        continue
      if resp['id'] != self.cmd:
        raise RuntimeError('Expected response id %d for %s, got %r' %
                           (self.cmd, method, resp))
      break

    if 'error' in resp:
      raise RuntimeError('%s failed: %s' % (method, json.dumps(resp['error'])))
    return resp.get('result', {})

  def recv(self):
    """Returns the next message: a buffered event if any, else a fresh read."""
    if self._events:
      return self._events.pop(0)
    return self._read_message()

  def req_memory_dump(self):
    """Requests one memory dump and blocks until it is reported successful.

    Raises:
      RuntimeError: if the response does not report success.
    """
    # Written without a newline and flushed so the progress text is visible
    # while the (potentially slow) request is still in flight.
    sys.stdout.write('Requesting memory dump...')
    sys.stdout.flush()
    resp = self.send('Tracing.requestMemoryDump')
    if not resp.get('success'):
      raise RuntimeError('Tracing.requestMemoryDump failed: %r' % (resp,))
    print(' ...done')

  def _prompt_for_dumps(self):
    """Requests one more dump per line entered, until 'q', Ctrl-C or EOF."""
    try:
      read_line = raw_input  # Python 2.
    except NameError:
      read_line = input  # Python 3 renamed raw_input() to input().
    while True:
      try:
        print('[Press enter to trigger a new dump, q to finish the trace]')
        answer = read_line()
      except (KeyboardInterrupt, EOFError):
        # Ctrl-C or a closed stdin: stop prompting, but let the caller end the
        # trace and save what has been collected so far.
        break
      if answer.strip() == 'q':
        break
      self.req_memory_dump()

  def _wait_for_event(self, method):
    """Returns the next event named |method|, discarding any other messages."""
    while True:
      msg = self.recv()
      if msg.get('method') == method:
        return msg

  def _save_stream(self, stream_handle, trace_fd):
    """Copies the stream behind |stream_handle| into |trace_fd| as UTF-8."""
    while True:
      chunk = self.send('IO.read', {'handle': stream_handle})
      trace_fd.write(chunk['data'].encode('utf-8'))
      if chunk['eof']:
        break
    self.send('IO.close', {'handle': stream_handle})

  def dump(self, trace_fd):
    """Records a memory-infra trace and writes it to |trace_fd|.

    Starts tracing, takes an initial memory dump, takes further dumps on
    request when stdin is a terminal, then ends the trace and reads it back.

    Args:
      trace_fd: Writable binary file object. It is always closed before this
          method returns, whether or not an error occurred.

    Raises:
      RuntimeError: if a protocol command fails.
    """
    trace_config = {
        'excludedCategories': ['*'],
        'includedCategories': ['disabled-by-default-memory-infra'],
        'memoryDumpConfig': {'triggers': []}
    }
    try:
      print('Starting trace with trace_config %s' % (trace_config,))
      params = {'traceConfig': trace_config, 'transferMode': 'ReturnAsStream'}
      self.send('Tracing.start', params)
      self.req_memory_dump()

      if sys.stdin.isatty():
        self._prompt_for_dumps()

      self.send('Tracing.end')

      # Wait for trace completion
      print('Flushing trace')
      done = self._wait_for_event('Tracing.tracingComplete')

      # Read back the trace stream
      self._save_stream(done['params']['stream'], trace_fd)
    finally:
      trace_fd.close()


def main():
  """Parses the command line, records a trace and prints where it was saved."""
  parser = argparse.ArgumentParser(
      description=__doc__,
      formatter_class=argparse.RawTextHelpFormatter)
  parser.add_argument('--host', default='localhost')
  # type=int makes argparse reject a non-numeric port up front.
  parser.add_argument('--port', '-p', type=int, default=9222)
  parser.add_argument('--output-trace', '-o', default=None)
  args = parser.parse_args()

  # Connect before creating the output file: if the browser is unreachable we
  # neither leave a stray temporary file behind (it is created with
  # delete=False) nor truncate an existing trace.
  cli = TracingDevtoolsClient(args.host, args.port)

  if args.output_trace is None:
    trace_fd = tempfile.NamedTemporaryFile(prefix='trace_', suffix='.json',
                                           delete=False)
  else:
    trace_fd = open(args.output_trace, 'wb')

  # dump() closes trace_fd, so the size is read back from the path.
  cli.dump(trace_fd)
  size_kb = os.stat(trace_fd.name).st_size // 1000
  print('\n%s: %d KB' % (trace_fd.name, size_kb))


if __name__ == '__main__':
  main()
