# Source code for etudes.decorators
import tensorflow as tf
import tensorflow_probability as tfp
# import tensorflow.keras.backend as K
from functools import wraps
def negate(fn):
    """Wrap ``fn`` so that the wrapper returns the negation of its result.

    All positional and keyword arguments are forwarded to ``fn`` unchanged;
    the only difference is that unary minus is applied to whatever ``fn``
    returns.
    """
    @wraps(fn)
    def negated(*args, **kwargs):
        result = fn(*args, **kwargs)
        return -result

    return negated
def unbatch(fn):
    """Adapt a batched function ``fn`` so it can be called on a single input.

    A leading axis of size one is inserted into ``input`` before calling
    ``fn``. The result is then passed through ``tf.squeeze`` with no axis
    argument, which removes *every* size-one dimension (not only the batch
    axis). For example, a ``(1, 1)`` result becomes a scalar.
    """
    @wraps(fn)
    def new_fn(input):
        return tf.squeeze(fn(tf.expand_dims(input, axis=0)))

    return new_fn
def value_and_gradient(value_fn):
    """Wrap ``value_fn`` so that it returns its value together with its gradient.

    The wrapper maps ``x`` to the pair ``(value_fn(x), gradient of that value
    with respect to x)``. It is equivalent to
    ``tfp.math.value_and_gradient(value_fn, x)``, with the only difference
    being that the gradients preserve their ``dtype``.
    """
    @wraps(value_fn)
    def value_and_gradient_fn(x):
        # Variables read inside `value_fn` are not watched automatically;
        # only `x` itself is tracked, so the gradient is taken w.r.t. `x`.
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(x)
            val = value_fn(x)
        return val, tape.gradient(val, x)

    return value_and_gradient_fn
def numpy_io(fn):
    """Wrap a TensorFlow function ``fn`` so it accepts and returns NumPy data.

    The single argument is converted with ``tf.convert_to_tensor`` before
    ``fn`` is called.

    - If ``fn`` returns a single tensor, that tensor's ``.numpy()`` value is
      returned.
    - Otherwise ``fn`` is expected to return a sequence of tensors, such as
      the ``(value, gradient)`` tuple returned by the wrapper built by
      ``value_and_gradient``. In that case a list holding the ``.numpy()``
      value of each element is returned.

    ``.numpy()`` is only available on eager tensors, so the wrapper must be
    called under eager execution.
    """
    @wraps(fn)
    def new_fn(input):
        input_tensor = tf.convert_to_tensor(input)
        output = fn(input_tensor)
        # A lone tensor must not be iterated: that would split it into rows
        # along its first axis, and fails outright for a 0-d (scalar) tensor.
        if tf.is_tensor(output):
            return output.numpy()
        return [tensor.numpy() for tensor in output]

    return new_fn