/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
|
#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
|
|
#include <functional>
#include <string>
#include <vector>

#include "utils/flatbuffers.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
|
|
namespace libtextclassifier3 {
|
|
// Names of the Lua metatable fields under which callbacks are registered when
// C++ data is exposed to scripts: `__len` (length operator), `__pairs`
// (`pairs()` iteration) and `__index` (member / element access).
static constexpr char const *kLengthKey = "__len";
static constexpr char const *kPairsKey = "__pairs";
static constexpr char const *kIndexKey = "__index";
|
|
// Converts a typed, const-qualified pointer into the untyped, mutable
// `void *` form that Lua expects for (light) user data.
// The constness is dropped only to satisfy the Lua C API signature.
template <typename T>
void *AsUserData(const T *value) {
  T *const mutable_value = const_cast<T *>(value);
  return static_cast<void *>(mutable_value);
}
|
// Reinterprets a non-pointer-to-object value (e.g. an integer handle) as the
// `void *` that Lua stores for light user data. The bit pattern of `value` is
// carried through `reinterpret_cast` unchanged.
template <typename T>
void *AsUserData(const T value) {
  void *const user_data = reinterpret_cast<void *>(value);
  return user_data;
}
|
|
// Retrieves up-values.
|
template <typename T>
|
T FromUpValue(const int index, lua_State *state) {
|
return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
|
}
|
|
class LuaEnvironment {
|
public:
|
// Wrapper for handling an iterator.
|
class Iterator {
|
public:
|
virtual ~Iterator() {}
|
static int NextCallback(lua_State *state);
|
static int LengthCallback(lua_State *state);
|
static int ItemCallback(lua_State *state);
|
static int IteritemsCallback(lua_State *state);
|
|
// Called when the next element of an iterator is fetched.
|
virtual int Next(lua_State *state) const = 0;
|
|
// Called when the length of the iterator is queried.
|
virtual int Length(lua_State *state) const = 0;
|
|
// Called when an item is queried.
|
virtual int Item(lua_State *state) const = 0;
|
|
// Called when a new iterator is started.
|
virtual int Iteritems(lua_State *state) const = 0;
|
|
protected:
|
static constexpr int kIteratorArgId = 1;
|
};
|
|
template <typename T>
|
class ItemIterator : public Iterator {
|
public:
|
void NewIterator(StringPiece name, const T *items, lua_State *state) const {
|
lua_newtable(state);
|
luaL_newmetatable(state, name.data());
|
lua_pushlightuserdata(state, AsUserData(this));
|
lua_pushlightuserdata(state, AsUserData(items));
|
lua_pushcclosure(state, &Iterator::ItemCallback, 2);
|
lua_setfield(state, -2, kIndexKey);
|
lua_pushlightuserdata(state, AsUserData(this));
|
lua_pushlightuserdata(state, AsUserData(items));
|
lua_pushcclosure(state, &Iterator::LengthCallback, 2);
|
lua_setfield(state, -2, kLengthKey);
|
lua_pushlightuserdata(state, AsUserData(this));
|
lua_pushlightuserdata(state, AsUserData(items));
|
lua_pushcclosure(state, &Iterator::IteritemsCallback, 2);
|
lua_setfield(state, -2, kPairsKey);
|
lua_setmetatable(state, -2);
|
}
|
|
int Iteritems(lua_State *state) const override {
|
lua_pushlightuserdata(state, AsUserData(this));
|
lua_pushlightuserdata(
|
state, lua_touserdata(state, lua_upvalueindex(kItemsArgId)));
|
lua_pushnumber(state, 0);
|
lua_pushcclosure(state, &Iterator::NextCallback, 3);
|
return /*num results=*/1;
|
}
|
|
int Length(lua_State *state) const override {
|
lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size());
|
return /*num results=*/1;
|
}
|
|
int Next(lua_State *state) const override {
|
return Next(FromUpValue<T *>(kItemsArgId, state),
|
lua_tointeger(state, lua_upvalueindex(kIterValueArgId)),
|
state);
|
}
|
|
int Next(const T *items, const int64 pos, lua_State *state) const {
|
if (pos >= items->size()) {
|
return 0;
|
}
|
|
// Update iterator value.
|
lua_pushnumber(state, pos + 1);
|
lua_replace(state, lua_upvalueindex(3));
|
|
// Push key.
|
lua_pushinteger(state, pos + 1);
|
|
// Push item.
|
return 1 + Item(items, pos, state);
|
}
|
|
int Item(lua_State *state) const override {
|
const T *items = FromUpValue<T *>(kItemsArgId, state);
|
switch (lua_type(state, -1)) {
|
case LUA_TNUMBER: {
|
// Lua is one based, so adjust the index here.
|
const int64 index =
|
static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1;
|
if (index < 0 || index >= items->size()) {
|
TC3_LOG(ERROR) << "Invalid index: " << index;
|
lua_error(state);
|
return 0;
|
}
|
return Item(items, index, state);
|
}
|
case LUA_TSTRING: {
|
size_t key_length = 0;
|
const char *key = lua_tolstring(state, /*idx=*/-1, &key_length);
|
return Item(items, StringPiece(key, key_length), state);
|
}
|
default:
|
TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1);
|
lua_error(state);
|
return 0;
|
}
|
}
|
|
virtual int Item(const T *items, const int64 pos,
|
lua_State *state) const = 0;
|
|
virtual int Item(const T *items, StringPiece key, lua_State *state) const {
|
TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString();
|
lua_error(state);
|
return 0;
|
}
|
|
protected:
|
static constexpr int kItemsArgId = 2;
|
static constexpr int kIterValueArgId = 3;
|
};
|
|
virtual ~LuaEnvironment();
|
LuaEnvironment();
|
|
// Compile a lua snippet into binary bytecode.
|
// NOTE: The compiled bytecode might not be compatible across Lua versions
|
// and platforms.
|
bool Compile(StringPiece snippet, std::string *bytecode);
|
|
typedef int (*CallbackHandler)(lua_State *);
|
|
// Loads default libraries.
|
void LoadDefaultLibraries();
|
|
// Provides a callback to Lua.
|
template <typename T, int (T::*handler)()>
|
void Bind() {
|
lua_pushlightuserdata(state_, static_cast<void *>(this));
|
lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
|
}
|
|
// Setup a named table that callsback whenever a member is accessed.
|
// This allows to lazily provide required information to the script.
|
template <typename T, int (T::*handler)()>
|
void BindTable(const char *name) {
|
lua_newtable(state_);
|
luaL_newmetatable(state_, name);
|
lua_pushlightuserdata(state_, static_cast<void *>(this));
|
lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
|
lua_setfield(state_, -2, kIndexKey);
|
lua_setmetatable(state_, -2);
|
}
|
|
void PushValue(const Variant &value);
|
|
// Reads a string from the stack.
|
StringPiece ReadString(const int index) const;
|
|
// Pushes a string to the stack.
|
void PushString(const StringPiece str);
|
|
// Pushes a flatbuffer to the stack.
|
void PushFlatbuffer(const reflection::Schema *schema,
|
const flatbuffers::Table *table);
|
|
// Reads a flatbuffer from the stack.
|
int ReadFlatbuffer(ReflectiveFlatbuffer *buffer);
|
|
// Runs a closure in protected mode.
|
// `func`: closure to run in protected mode.
|
// `num_lua_args`: number of arguments from the lua stack to process.
|
// `num_results`: number of result values pushed on the stack.
|
int RunProtected(const std::function<int()> &func, const int num_args = 0,
|
const int num_results = 0);
|
|
lua_State *state() const { return state_; }
|
|
protected:
|
lua_State *state_;
|
|
private:
|
// Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
|
static void PushFlatbuffer(const char *name, const reflection::Schema *schema,
|
const reflection::Object *type,
|
const flatbuffers::Table *table, lua_State *state);
|
static int GetFieldCallback(lua_State *state);
|
static int GetField(const reflection::Schema *schema,
|
const reflection::Object *type,
|
const flatbuffers::Table *table, lua_State *state);
|
|
template <typename T, int (T::*handler)()>
|
static int Dispatch(lua_State *state) {
|
T *env = FromUpValue<T *>(1, state);
|
return ((*env).*handler)();
|
}
|
};
|
|
// Free-function variant of `LuaEnvironment::Compile` with the same signature.
// NOTE(review): presumably compiles the Lua `snippet` into `*bytecode` and
// returns whether that succeeded; the definition is not part of this header,
// so confirm there.
bool Compile(StringPiece snippet, std::string *bytecode);
|
|
} // namespace libtextclassifier3
|
|
#endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
|