Enable TensorFloat32 with XLA #44887

Open
n2cholas opened this issue Nov 15, 2020 · 0 comments
Assignees
r4nt
Labels
comp:xla XLA, stat:awaiting tensorflower Status - Awaiting response from tensorflower, type:feature Feature requests

Comments

@n2cholas (Contributor)

System information

  • TensorFlow version (you are using): 2.5.0-dev20201115
  • Are you willing to contribute it (Yes/No): No (don't know how to)

Describe the feature and the current behavior/state.
Currently, TF32 is used in the normal (non-XLA) runtime, but not when the computation is compiled with XLA.

Will this change the current API? How?
No.

Who will benefit from this feature?
Anyone using Ampere GPUs and XLA compilation.

Any other info.
I verified that TF32 was not being used with XLA using the following script:

import tensorflow as tf

# Log which device each op runs on and make sure TF32 is enabled.
tf.debugging.set_log_device_placement(True)
tf.config.experimental.enable_tensor_float_32_execution(True)

# XLA-compiled matmul.
@tf.function(experimental_compile=True)
def f(x):
    return x @ x

with tf.device("/GPU:0"):
    x = tf.cast(tf.random.uniform(shape=(16, 16)), tf.float32)

f(x)
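
As a side note (not something I ran, just a sketch): the same kernel launches could presumably also be captured without dlprof by using the built-in TensorFlow profiler and then inspecting the kernel names in TensorBoard's trace viewer. Assuming /tmp/tf32_xla_profile is a writable logdir:

import tensorflow as tf

tf.config.experimental.enable_tensor_float_32_execution(True)

@tf.function(experimental_compile=True)
def f(x):
    return x @ x

with tf.device("/GPU:0"):
    x = tf.random.uniform(shape=(16, 16))

# Warm-up call so XLA compilation is not mixed into the trace.
f(x)

# Capture a single traced call; the launched GPU kernel names appear in
# TensorBoard's trace viewer for this logdir.
tf.profiler.experimental.start("/tmp/tf32_xla_profile")
f(x)
tf.profiler.experimental.stop()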

I ran this on an A100 with dlprof.

With experimental_compile=False (TF32 enabled), the following kernel ran:
[dlprof screenshot: kernel name]

With experimental_compile=False and tf.config.experimental.enable_tensor_float_32_execution(False), the following kernel ran:
[dlprof screenshot: kernel name]

With experimental_compile=True and tf.config.experimental.enable_tensor_float_32_execution(True), the same kernel as in the TF32-disabled case ran:
[dlprof screenshot: kernel name]

This indicates to me that TF32 is not enabled with XLA. Please let me know if I'm doing something wrong. Thanks!
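
An additional way to sanity-check this without a profiler would be a rough numerical test (just a sketch, assuming TF32's 10-bit input mantissa produces a noticeably larger matmul error than pure FP32): compare the XLA-compiled result against a float64 reference computed on the host, once with TF32 enabled and once with it disabled. A clearly larger error in the enabled case would confirm TF32 is actually in use.

import numpy as np
import tensorflow as tf

tf.config.experimental.enable_tensor_float_32_execution(True)

@tf.function(experimental_compile=True)
def matmul_xla(x):
    return x @ x

with tf.device("/GPU:0"):
    x = tf.random.uniform(shape=(1024, 1024))

# Full-precision reference computed in float64 on the host.
x64 = x.numpy().astype(np.float64)
ref = x64 @ x64

out = matmul_xla(x).numpy()

# If TF32 is active, this error is typically an order of magnitude or more
# larger than with plain FP32 inputs; rerunning the script with
# enable_tensor_float_32_execution(False) makes the difference obvious.
print("max abs error vs float64:", np.abs(out - ref).max())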

@n2cholas n2cholas added the type:feature Feature requests label Nov 15, 2020
@ravikyram ravikyram added the comp:xla XLA label Nov 17, 2020
@ravikyram ravikyram assigned ymodak and unassigned ravikyram Nov 17, 2020
@ymodak ymodak assigned r4nt and unassigned ymodak Nov 17, 2020
@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 17, 2020