Refactor Script registration (#11280) bc6f965d1f

Co-authored-by: Philip Chung <philterdesign@gmail.com>
This commit is contained in:
philter
2025-12-16 19:52:05 +00:00
parent fbd07113b2
commit c52a714032
9 changed files with 255 additions and 168 deletions

View File

@@ -1 +1 @@
752595c5904a484d102b2a9a56bbea9921a3a00d
bc6f965d1f8c5e298073dc65b6298806baaaa9b9

View File

@@ -111,6 +111,23 @@ public:
virtual void registrationComplete(int ref) {}
virtual Span<uint8_t> moduleBytecode() { return Span<uint8_t>(); }
virtual bool isProtocolScript() = 0;
virtual bool verified() const { return false; }
void addMissingDependency(std::string name) { m_dependencies.insert(name); }
void clearMissingDependency(std::string name)
{
auto it = m_dependencies.find(name);
if (it != m_dependencies.end())
{
m_dependencies.erase(it);
}
}
std::unordered_set<std::string> missingDependencies()
{
return m_dependencies;
}
private:
std::unordered_set<std::string> m_dependencies;
};
class ScriptAsset : public ScriptAssetBase,
@@ -122,7 +139,7 @@ public:
#ifdef WITH_RIVE_SCRIPTING
friend class ScriptAssetImporter;
bool verified() const { return m_verified; }
bool verified() const override { return m_verified; }
Span<uint8_t> moduleBytecode() override { return m_bytecode; }
#endif

View File

@@ -33,6 +33,7 @@
#include <chrono>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <string>
#include <vector>
@@ -745,27 +746,30 @@ public:
virtual void printEndLine() = 0;
virtual int pCall(lua_State* state, int nargs, int nresults) = 0;
void queuePendingModule(ModuleDetails* moduleDetails);
void clearPendingModule(const std::string& name);
// Add a module to be registered later via performRegistration()
void addModule(ModuleDetails* moduleDetails);
// Perform registration of all added modules, handling dependencies and
// retries
void performRegistration(lua_State* state);
// Called when a module is required but not found during registration
void recordMissingDependency(const std::string& requiringModule,
const std::string& missingModule);
private:
bool tryRegisterModule(lua_State* state,
ModuleDetails* moduleDetails,
int& functionRef);
void retryPendingModules(lua_State* state);
bool tryRegisterModule(lua_State* state, ModuleDetails* moduleDetails);
void sortNextModule(ModuleDetails* module,
std::vector<ModuleDetails*>* pendingModules,
std::vector<ModuleDetails*>* sortedModules,
std::unordered_set<ModuleDetails*>* visitedModules);
// Called when a module successfully registers
void onModuleRegistered(ModuleDetails* moduleDetails);
private:
Factory* m_factory;
std::vector<ModuleDetails*> m_pendingModules;
std::vector<ModuleDetails*> m_modulesToRegister;
std::unordered_map<std::string, ModuleDetails*> m_moduleLookup;
std::unordered_set<ModuleDetails*> m_pendingModules;
};
class ScriptingVM

View File

@@ -320,15 +320,6 @@ ScriptedArtboard::~ScriptedArtboard()
m_scriptReffedArtboard = nullptr;
}
void ScriptedArtboard::cleanupDataRef(lua_State* L)
{
if (m_dataRef != 0)
{
lua_unref(L, m_dataRef);
m_dataRef = 0;
}
}
static int node_index(lua_State* L)
{
int atom;
@@ -569,22 +560,9 @@ static int node_namecall(lua_State* L)
return 0;
}
static void scripted_artboard_dtor(lua_State* L, void* data)
{
ScriptedArtboard* artboard = static_cast<ScriptedArtboard*>(data);
// Clean up m_dataRef before calling the C++ destructor
// (we can't do this in ~ScriptedArtboard() because it doesn't have access
// to lua_State*)
artboard->cleanupDataRef(L);
// Call the C++ destructor
artboard->~ScriptedArtboard();
}
int luaopen_rive_artboards(lua_State* L)
{
lua_register_rive<ScriptedArtboard>(L);
// Override the default destructor to clean up m_dataRef first
lua_setuserdatadtor(L, ScriptedArtboard::luaTag, scripted_artboard_dtor);
lua_pushcfunction(L, artboard_index, nullptr);
lua_setfield(L, -2, "__index");

View File

@@ -5,6 +5,9 @@
#include <unordered_map>
#include <unordered_set>
#include <string>
#include <queue>
#include <vector>
#include <algorithm>
using namespace rive;
@@ -262,6 +265,17 @@ static int lua_requireinternal(lua_State* L, const char* requirerChunkname)
return 1;
}
// Record missing dependency if we're registering a module
if (requirerChunkname)
{
ScriptingContext* context =
static_cast<ScriptingContext*>(lua_getthreaddata(L));
if (context)
{
context->recordMissingDependency(requirerChunkname, path);
}
}
luaL_error(L, "require could not find a script named %s", path);
return 0;
}
@@ -458,14 +472,19 @@ void ScriptingVM::dumpStack(lua_State* state) { dump_stack(state); }
void ScriptingContext::addModule(ModuleDetails* moduleDetails)
{
m_modulesToRegister.push_back(moduleDetails);
m_moduleLookup[moduleDetails->moduleName()] = moduleDetails;
}
bool ScriptingContext::tryRegisterModule(lua_State* state,
ModuleDetails* moduleDetails,
int& functionRef)
ModuleDetails* moduleDetails)
{
if (!moduleDetails->verified())
{
return false;
}
const std::string& name = moduleDetails->moduleName();
functionRef = 0;
bool registerSuccess = false;
int functionRef = 0;
if (moduleDetails->isProtocolScript())
{
if (ScriptingVM::registerScript(state,
@@ -478,7 +497,7 @@ bool ScriptingContext::tryRegisterModule(lua_State* state,
functionRef = lua_ref(state, -1);
}
lua_pop(state, 1);
return true;
registerSuccess = true;
}
}
else
@@ -487,104 +506,138 @@ bool ScriptingContext::tryRegisterModule(lua_State* state,
name.c_str(),
moduleDetails->moduleBytecode()))
{
return true;
registerSuccess = true;
}
}
if (registerSuccess)
{
moduleDetails->registrationComplete(functionRef);
onModuleRegistered(moduleDetails);
return true;
}
return false;
}
void ScriptingContext::retryPendingModules(lua_State* state)
void ScriptingContext::performRegistration(lua_State* state)
{
bool anyRetried = true;
while (anyRetried)
// Loop over all of the modules once. We need do a tryRegister
// pass on each module in order to determine if it has any
// required dependencies
for (ModuleDetails* moduleDetails : m_modulesToRegister)
{
anyRetried = false;
// Track which modules we've tried in this iteration to avoid infinite
// loops
std::unordered_set<std::string> triedThisCycle;
auto currentPending = m_pendingModules;
for (ModuleDetails* moduleDetails : currentPending)
// Skip if already registered
if (checkRegisteredModules(state,
moduleDetails->moduleName().c_str()) == 1)
{
const std::string& name = moduleDetails->moduleName();
continue;
}
tryRegisterModule(state, moduleDetails);
}
// Skip if we've already tried this module in this iteration
if (triedThisCycle.find(name) != triedThisCycle.end())
{
continue;
}
// If any modules had dependencies, resolve their registration order
// and try registering again
if (!m_pendingModules.empty())
{
std::vector<ModuleDetails*> pendingModules;
for (auto module : m_pendingModules)
{
pendingModules.push_back(module);
}
std::vector<ModuleDetails*> sortedModules;
std::unordered_set<ModuleDetails*> visitedModules;
triedThisCycle.insert(name);
ModuleDetails* module = pendingModules.back();
pendingModules.pop_back();
sortNextModule(module,
&pendingModules,
&sortedModules,
&visitedModules);
int functionRef = 0;
if (tryRegisterModule(state, moduleDetails, functionRef))
{
// Successfully registered, remove from pending list
clearPendingModule(name);
// Register modules in sorted order
for (ModuleDetails* moduleDetails : sortedModules)
{
tryRegisterModule(state, moduleDetails);
}
}
moduleDetails->registrationComplete(
moduleDetails->isProtocolScript() ? functionRef : 0);
m_modulesToRegister.clear();
m_pendingModules.clear();
}
anyRetried = true;
break;
}
void ScriptingContext::sortNextModule(
ModuleDetails* module,
std::vector<ModuleDetails*>* pendingModules,
std::vector<ModuleDetails*>* sortedModules,
std::unordered_set<ModuleDetails*>* visitedModules)
{
// If already visited, skip
if (visitedModules->find(module) != visitedModules->end())
{
return;
}
auto dependencies = module->missingDependencies();
for (const auto& dependencyName : dependencies)
{
auto lookupIt = m_moduleLookup.find(dependencyName);
if (lookupIt != m_moduleLookup.end())
{
ModuleDetails* dependencyModule = lookupIt->second;
// Recursively process the dependency
sortNextModule(dependencyModule,
pendingModules,
sortedModules,
visitedModules);
}
}
if (std::find(sortedModules->begin(), sortedModules->end(), module) ==
sortedModules->end())
{
sortedModules->push_back(module);
}
visitedModules->insert(module);
if (!pendingModules->empty())
{
ModuleDetails* nextModule = pendingModules->back();
pendingModules->pop_back();
sortNextModule(nextModule,
pendingModules,
sortedModules,
visitedModules);
}
}
void ScriptingContext::recordMissingDependency(
const std::string& requiringModule,
const std::string& missingModule)
{
if (!requiringModule.empty())
{
ModuleDetails* moduleDetails = m_moduleLookup[requiringModule];
if (moduleDetails != nullptr)
{
moduleDetails->addMissingDependency(missingModule);
m_pendingModules.insert(moduleDetails);
}
}
}
void ScriptingContext::performRegistration(lua_State* state)
void ScriptingContext::onModuleRegistered(ModuleDetails* moduleDetails)
{
// Try to register all modules
std::unordered_set<std::string> tried;
bool anyRegistered = true;
while (anyRegistered)
for (ModuleDetails* module : m_modulesToRegister)
{
anyRegistered = false;
for (ModuleDetails* moduleDetails : m_modulesToRegister)
if (!module->missingDependencies().empty())
{
const std::string& name = moduleDetails->moduleName();
// Skip if already registered
if (checkRegisteredModules(state, name.c_str()) == 1)
{
continue;
}
// Skip if we've already tried this module
if (tried.find(name) != tried.end())
{
continue;
}
tried.insert(name);
// Try to register the module
int functionRef = 0;
if (tryRegisterModule(state, moduleDetails, functionRef))
{
// Successfully registered
anyRegistered = true;
moduleDetails->registrationComplete(
moduleDetails->isProtocolScript() ? functionRef : 0);
this->retryPendingModules(state);
break;
}
else
{
// Registration failed - likely missing dependencies
// Queue for retry
queuePendingModule(moduleDetails);
}
moduleDetails->clearMissingDependency(moduleDetails->moduleName());
}
}
// Clear the modules list after registration attempt
m_modulesToRegister.clear();
auto it = m_pendingModules.find(moduleDetails);
if (it != m_pendingModules.end())
{
m_pendingModules.erase(it);
}
}
bool ScriptingVM::registerScript(lua_State* state,
@@ -651,22 +704,6 @@ bool ScriptingVM::registerScript(const char* name, Span<uint8_t> bytecode)
return registerScript(m_state, name, bytecode);
}
void ScriptingContext::queuePendingModule(ModuleDetails* moduleDetails)
{
m_pendingModules.push_back(moduleDetails);
}
void ScriptingContext::clearPendingModule(const std::string& name)
{
m_pendingModules.erase(std::remove_if(m_pendingModules.begin(),
m_pendingModules.end(),
[&name](ModuleDetails* module) {
return module->moduleName() ==
name;
}),
m_pendingModules.end());
}
int CPPRuntimeScriptingContext::pCall(lua_State* state, int nargs, int nresults)
{
// calculate stack position for message handler

Binary file not shown.

View File

@@ -3,7 +3,6 @@
#include "scripting_test_utilities.hpp"
#include "rive/animation/state_machine_instance.hpp"
#include "rive/lua/rive_lua_libs.hpp"
#include "rive/viewmodel/viewmodel_instance_number.hpp"
#include "rive/viewmodel/viewmodel_instance_string.hpp"
#include "rive_file_reader.hpp"
@@ -405,42 +404,4 @@ TEST_CASE("scripted string converter", "[silver]")
artboard->draw(renderer.get());
CHECK(silver.matches("script_string_converter"));
}
TEST_CASE("scripted data converter using multi chain requires", "[silver]")
{
rive::SerializingFactory silver;
auto file = ReadRiveFile("assets/script_dependency_test.riv", &silver);
auto artboard = file->artboardNamed("Artboard");
silver.frameSize(artboard->width(), artboard->height());
REQUIRE(artboard != nullptr);
auto stateMachine = artboard->stateMachineAt(0);
int viewModelId = artboard.get()->viewModelId();
auto vmi = viewModelId == -1
? file->createViewModelInstance(artboard.get())
: file->createViewModelInstance(viewModelId, 0);
stateMachine->bindViewModelInstance(vmi);
stateMachine->advanceAndApply(0.1f);
auto renderer = silver.makeRenderer();
artboard->draw(renderer.get());
rive::ViewModelInstanceNumber* num =
vmi->propertyValue("InputValue1")->as<rive::ViewModelInstanceNumber>();
REQUIRE(num != nullptr);
int counter = 0;
int frames = 30;
for (int i = 0; i < frames; i++)
{
num->propertyValue(counter);
silver.addFrame();
stateMachine->advanceAndApply(0.016f);
artboard->draw(renderer.get());
counter += 5;
}
CHECK(silver.matches("script_converter_with_dependency"));
}

View File

@@ -0,0 +1,90 @@
#include "catch.hpp"
#include "rive_file_reader.hpp"
#include "rive/animation/state_machine_instance.hpp"
#include "rive/lua/rive_lua_libs.hpp"
#include "rive/assets/script_asset.hpp"
#include "rive/viewmodel/viewmodel_instance_number.hpp"
#include "rive/viewmodel/viewmodel_instance_string.hpp"
#include "utils/serializing_factory.hpp"
using namespace rive;
TEST_CASE("scripted data converter number using multi chain requires",
"[silver]")
{
rive::SerializingFactory silver;
auto file = ReadRiveFile("assets/script_dependency_test.riv", &silver);
auto artboard = file->artboardNamed("Artboard");
silver.frameSize(artboard->width(), artboard->height());
REQUIRE(artboard != nullptr);
auto stateMachine = artboard->stateMachineAt(0);
int viewModelId = artboard.get()->viewModelId();
auto vmi = viewModelId == -1
? file->createViewModelInstance(artboard.get())
: file->createViewModelInstance(viewModelId, 0);
stateMachine->bindViewModelInstance(vmi);
stateMachine->advanceAndApply(0.1f);
auto renderer = silver.makeRenderer();
artboard->draw(renderer.get());
rive::ViewModelInstanceNumber* num =
vmi->propertyValue("InputValue1")->as<rive::ViewModelInstanceNumber>();
REQUIRE(num != nullptr);
int counter = 0;
int frames = 30;
for (int i = 0; i < frames; i++)
{
num->propertyValue(counter);
silver.addFrame();
stateMachine->advanceAndApply(0.016f);
artboard->draw(renderer.get());
counter += 5;
}
CHECK(silver.matches("script_converter_with_dependency"));
}
TEST_CASE("scripted data converter string using multi chain requires",
"[silver]")
{
rive::SerializingFactory silver;
auto file = ReadRiveFile("assets/script_dependency_test2.riv", &silver);
auto artboard = file->artboardNamed("Artboard");
silver.frameSize(artboard->width(), artboard->height());
REQUIRE(artboard != nullptr);
auto stateMachine = artboard->stateMachineAt(0);
int viewModelId = artboard.get()->viewModelId();
auto vmi = viewModelId == -1
? file->createViewModelInstance(artboard.get())
: file->createViewModelInstance(viewModelId, 0);
stateMachine->bindViewModelInstance(vmi);
stateMachine->advanceAndApply(0.1f);
auto renderer = silver.makeRenderer();
artboard->draw(renderer.get());
rive::ViewModelInstanceString* str =
vmi->propertyValue("InputString")->as<rive::ViewModelInstanceString>();
REQUIRE(str != nullptr);
std::vector<std::string> values = {"Hello world!",
"1,2,3",
"rive scripting",
"testing testing testing",
"Script Data Converter"};
for (int i = 0; i < values.size(); i++)
{
str->propertyValue(values[i]);
silver.addFrame();
stateMachine->advanceAndApply(0.016f);
artboard->draw(renderer.get());
}
CHECK(silver.matches("script_converter_with_dependency_2"));
}