/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#ifndef TENSORFLOW_C_KERNELS_H_
|
#define TENSORFLOW_C_KERNELS_H_
|
|
#include "tensorflow/c/c_api.h"
|
|
#ifdef __cplusplus
|
extern "C" {
|
#endif
|
|
// --------------------------------------------------------------------------
|
// C API for TensorFlow Kernels.
|
//
|
// This API allows developers to register custom kernel implementations for
|
// TensorFlow.
|
//
|
// See c_api.h header comments for a discussion about API conventions.
|
//
|
// Users wishing to extend TensorFlow with new kernels will call
|
// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with
|
// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided
|
// kernels when necessary.
|
|
// Opaque handle types. Their struct definitions are not exposed by this
// header, so callers only ever handle them through pointers passed to and
// returned from the functions declared below.

// Builder describing a kernel implementation; see TF_NewKernelBuilder.
typedef struct TF_KernelBuilder TF_KernelBuilder;
// Context passed to a kernel's create_func; owned by TensorFlow and valid only
// until create_func returns.
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
// Context passed to a kernel's compute_func; owned by TensorFlow and valid only
// until compute_func returns.
typedef struct TF_OpKernelContext TF_OpKernelContext;
|
|
// Allocates a new kernel builder and returns a pointer to it.
//
// NOTE(review): `op_name` and `device_name` are not described here; presumably
// they name the op being implemented and the device type the kernel targets —
// confirm against the implementation before relying on this.
//
// If non-null, TensorFlow will call create_func when it needs to instantiate
// the kernel. The pointer returned by create_func will be passed to
// compute_func and delete_func, thereby functioning as a "this" pointer for
// referring to kernel instances.
//
// The TF_OpKernelConstruction pointer passed to create_func is owned by
// TensorFlow and will be deleted once create_func returns. It must not be used
// after this.
//
// When TensorFlow needs to perform a computation with this kernel, it will
// call compute_func. This function will receive the pointer returned by
// create_func (or null if no create_func was provided), along with the inputs
// to the computation.
//
// The TF_OpKernelContext pointer received by compute_func is owned by
// TensorFlow and will be deleted once compute_func returns. It must not be used
// after this.
//
// Finally, when TensorFlow no longer needs the kernel, it will call
// delete_func if one is provided. This function will receive the pointer
// returned by `create_func`, or null if no `create_func` was provided.
//
// The caller should pass the result of this function to
// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for
// some reason, the kernel builder will not be registered, the caller should
// delete it with TF_DeleteKernelBuilder.
TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder(
    const char* op_name, const char* device_name,
    void* (*create_func)(TF_OpKernelConstruction*),
    void (*compute_func)(void*, TF_OpKernelContext*),
    void (*delete_func)(void*));
|
|
// Registers the given kernel builder with the TensorFlow runtime. If
// registration fails, *status is populated with an error.
//
// This call takes ownership of the `builder` pointer, so the caller must not
// use it or pass it to TF_DeleteKernelBuilder afterwards.
// NOTE(review): ownership transfer is stated unconditionally here; confirm it
// also holds when registration fails.
// NOTE(review): the meaning of `kernel_name` (and whether it must be unique)
// is not documented here — confirm against the implementation.
TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name,
                                                    TF_KernelBuilder* builder,
                                                    TF_Status* status);
|
|
// Deletes the given TF_KernelBuilder. Call this only for a builder that was
// NOT handed to TF_RegisterKernelBuilder, because that call takes ownership of
// the builder.
TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
|
|
// --------------------------------------------------------------------------
|
// OpKernelContext routines
|
|
// TF_NumInputs returns the number of inputs available in ctx. Valid input
// indices for TF_GetInput are 0 .. TF_NumInputs(ctx) - 1.
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
|
|
// TF_NumOutputs returns the number of outputs the kernel is expected to place
// in *ctx. Valid output indices for TF_SetOutput and TF_ExpectedOutputDataType
// are 0 .. TF_NumOutputs(ctx) - 1.
TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx);
|
|
// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is
// populated and its ownership is passed to the caller, who is then
// responsible for releasing it (presumably with TF_DeleteTensor from
// c_api.h — confirm). In any other case, *tensor is not modified.
//
// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE.
TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i,
                                       TF_Tensor** tensor, TF_Status* status);
|
|
// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but
// TF_OK, ctx is left unmodified.
//
// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE.
//
// NOTE(review): `tensor` is taken as const and no ownership transfer is stated
// here; presumably the caller keeps ownership and must still release its own
// handle — confirm against the implementation.
TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i,
                                        const TF_Tensor* tensor,
                                        TF_Status* status);
|
|
// Notifies the given OpKernelConstruction that kernel construction has failed.
// Intended to be called from a kernel's create_func, which receives `ctx`.
// NOTE(review): presumably `status` carries the (non-OK) error being reported
// — confirm.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure(
    TF_OpKernelConstruction* ctx, TF_Status* status);
|
|
// Notifies the given OpKernelContext that the kernel's compute function has
// failed. Intended to be called from a kernel's compute_func, which receives
// `ctx`.
// NOTE(review): presumably `status` carries the (non-OK) error being reported
// — confirm.
TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx,
                                                      TF_Status* status);
|
|
// Returns the expected data type of the ith output. If i < 0 or
// i >= TF_NumOutputs(ctx), the program aborts; there is no status parameter
// for reporting the error, so callers must range-check i themselves.
TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
    TF_OpKernelContext* ctx, int i);
|
|
// Returns the step ID of the given context.
TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
|
|
// Interprets the named kernel construction attribute as a TF_DataType and
// places it into *val. On success, *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// TF_DataType, *status is populated with an error.
// NOTE(review): whether *val is left untouched on failure is not stated here
// — confirm before relying on its value after an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType(
    TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val,
    TF_Status* status);
|
|
#ifdef __cplusplus
|
} /* end extern "C" */
|
#endif
|
|
#endif // TENSORFLOW_C_KERNELS_H_
|