From 652c5e46d94f13c29669bae96ac7e0515a9c6490 Mon Sep 17 00:00:00 2001 From: Cassandra Xia Date: Mon, 16 Apr 2018 21:05:23 -0700 Subject: [PATCH] Add TensorBoard flag `samples_per_plugin` that allows for custom specification of samples to keep per summary type. Previously, TensorBoard always downsampled summaries for OOM reasons but some users wanted the ability to keep all their summaries of some types. With this flag, --samples_per_plugin='histogram=2,images=0' keeps TensorBoard defaults for other summaries, restricts the number of histograms to 2, and keeps all image summaries. --- tensorboard/backend/application.py | 18 ++++++++++++++++-- tensorboard/program.py | 10 ++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py index d23e11599a..1d945d2fc6 100644 --- a/tensorboard/backend/application.py +++ b/tensorboard/backend/application.py @@ -71,6 +71,20 @@ _VALID_PLUGIN_RE = re.compile(r'^[A-Za-z0-9_.-]+$') +def tensor_size_guidance_from_flags(flags): + """Apply user per-summary size guidance overrides.""" + + tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE) + if not flags or not flags.samples_per_plugin: + return tensor_size_guidance + + for token in flags.samples_per_plugin.split(','): + k, v = token.strip().split('=') + tensor_size_guidance[k] = int(v) + + return tensor_size_guidance + + def standard_tensorboard_wsgi( logdir, purge_orphaned_data, @@ -90,13 +104,13 @@ def standard_tensorboard_wsgi( reload_interval: The interval at which the backend reloads more data in seconds. Zero means load once at startup; negative means never load. plugins: A list of constructor functions for TBPlugin subclasses. - path_prefix: A prefix of the path when app isn't served from root. db_uri: A String containing the URI of the SQL database for persisting data, or empty for memory-only mode. assets_zip_provider: See TBContext documentation for more information. If this value is not specified, this function will attempt to load the `tensorboard.default` module to use the default. This behavior might be removed in the future. + path_prefix: A prefix of the path when app isn't served from root. window_title: A string specifying the the window title. max_reload_threads: The max number of threads that TensorBoard can use to reload runs. Not relevant for db mode. Each thread reloads one run @@ -110,7 +124,7 @@ def standard_tensorboard_wsgi( assets_zip_provider = default.get_assets_zip_provider() multiplexer = event_multiplexer.EventMultiplexer( size_guidance=DEFAULT_SIZE_GUIDANCE, - tensor_size_guidance=DEFAULT_TENSOR_SIZE_GUIDANCE, + tensor_size_guidance=tensor_size_guidance_from_flags(flags), purge_orphaned_data=purge_orphaned_data, max_reload_threads=max_reload_threads) db_module, db_connection_provider = get_database_info(db_uri) diff --git a/tensorboard/program.py b/tensorboard/program.py index 952cf99828..ae9ffa24bd 100644 --- a/tensorboard/program.py +++ b/tensorboard/program.py @@ -128,6 +128,16 @@ 'The max number of threads that TensorBoard can use to reload runs. Not ' 'relevant for db mode. Each thread reloads one run at a time.') +tf.flags.DEFINE_string( + 'samples_per_plugin', '', 'An optional comma separated list of ' + 'plugin_name=num_samples pairs to explicitly specify how many samples to ' + 'keep per tag for that plugin. For unspecified plugins, TensorBoard ' + 'randomly downsamples logged summaries to reasonable values to prevent ' + 'out-of-memory errors for long running jobs. This flag allows fine control ' + 'over that downsampling. Note that 0 means keep all samples of that type. ' + 'For instance, "scalars=500,images=0" keeps 500 scalars and all images. ' + 'Most users should not need to set this flag.') + FLAGS = tf.flags.FLAGS