Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 78 additions & 12 deletions registry/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/meshery/meshkit/errors"
"github.com/meshery/meshkit/files"
"github.com/meshery/meshkit/generators"
"github.com/meshery/meshkit/generators/github"
"github.com/meshery/meshkit/generators/models"
"github.com/meshery/meshkit/models/meshmodel/entity"
"github.com/meshery/meshkit/utils"
Expand All @@ -31,11 +32,86 @@ import (
"github.com/meshery/schemas/models/v1beta1/subcategory"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/semaphore"
"golang.org/x/sync/singleflight"
"google.golang.org/api/sheets/v4"
)

var modelToCompGenerateTracker = store.NewGenericThreadSafeStore[compGenerateTracker]()

type generatorFactory func(registrant, url, packageName string) (models.PackageManager, error)

type packageFetcher struct {
newGenerator generatorFactory
cache sync.Map
fetchGroup singleflight.Group
}

func newPackageFetcher(newGenerator generatorFactory) *packageFetcher {
return &packageFetcher{
newGenerator: newGenerator,
}
}

func packageCacheKey(registrant, sourceURL, modelName string) string {
normalizedRegistrant := utils.ReplaceSpacesAndConvertToLowercase(registrant)
if normalizedRegistrant == artifactHub {
return fmt.Sprintf("%s\x00%s\x00%s", normalizedRegistrant, sourceURL, utils.ReplaceSpacesAndConvertToLowercase(modelName))
}

return fmt.Sprintf("%s\x00%s", normalizedRegistrant, sourceURL)
}

// GitHub packages derive generated component metadata from the model name, so
// reuse the fetched content but return a per-model copy with the requested name.
func packageForModel(registrant, modelName string, pkg models.Package) models.Package {
if utils.ReplaceSpacesAndConvertToLowercase(registrant) != gitHub {
return pkg
}

switch typedPkg := pkg.(type) {
case github.GitHubPackage:
typedPkg.Name = modelName
return typedPkg
case *github.GitHubPackage:
clonedPkg := *typedPkg
clonedPkg.Name = modelName
return &clonedPkg
default:
return pkg
}
}

func (pf *packageFetcher) getPackage(registrant, sourceURL, modelName string) (models.Package, error) {
cacheKey := packageCacheKey(registrant, sourceURL, modelName)
if cachedPkg, ok := pf.cache.Load(cacheKey); ok {
return packageForModel(registrant, modelName, cachedPkg.(models.Package)), nil
}

fetchedPkg, err, _ := pf.fetchGroup.Do(cacheKey, func() (interface{}, error) {
generator, err := pf.newGenerator(registrant, sourceURL, modelName)
if err != nil {
return nil, err
}

if utils.ReplaceSpacesAndConvertToLowercase(registrant) == artifactHub {
RateLimitArtifactHub()
}

pkg, err := generator.GetPackage()
if err != nil {
return nil, err
}

pf.cache.Store(cacheKey, pkg)
return pkg, nil
})
if err != nil {
return nil, err
}

return packageForModel(registrant, modelName, fetchedPkg.(models.Package)), nil
}

type compGenerateTracker struct {
totalComps int
version string
Expand Down Expand Up @@ -800,6 +876,7 @@ func InvokeGenerationFromSheet(wg *sync.WaitGroup, path string, modelsheetID, co
// - Latest version only filtering
func InvokeGenerationFromSheetWithOptions(wg *sync.WaitGroup, path string, modelsheetID, componentSheetID int64, spreadsheeetID string, modelName string, modelCSVFilePath, componentCSVFilePath, spreadsheeetCred, relationshipCSVFilePath string, relationshipSheetID int64, srv *sheets.Service, opts GenerationOptions) error {
weightedSem := semaphore.NewWeighted(20)
packageFetcher := newPackageFetcher(generators.NewGenerator)
url := GoogleSpreadSheetURL + spreadsheeetID
totalAvailableModels := 0
spreadsheeetChan := make(chan SpreadsheetData)
Expand Down Expand Up @@ -924,19 +1001,8 @@ func InvokeGenerationFromSheetWithOptions(wg *sync.WaitGroup, path string, model
}

Log.Debug(fmt.Sprintf("Model %s: Creating generator for registrant: %s, source: %s", model.Model, model.Registrant, model.SourceURL))

generator, genErr := generators.NewGenerator(model.Registrant, model.SourceURL, model.Model)
if genErr != nil {
done <- ErrGenerateModel(genErr, model.Model)
return
}

if utils.ReplaceSpacesAndConvertToLowercase(model.Registrant) == "artifacthub" {
RateLimitArtifactHub()
}

Log.Debug(fmt.Sprintf("Model %s: Fetching package from source", model.Model))
pkg, genErr := generator.GetPackage()
pkg, genErr := packageFetcher.getPackage(model.Registrant, model.SourceURL, model.Model)
if genErr != nil {
done <- ErrGenerateModel(genErr, model.Model)
return
Expand Down
131 changes: 131 additions & 0 deletions registry/model_generation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,46 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

artifacthubgen "github.com/meshery/meshkit/generators/artifacthub"
githubgen "github.com/meshery/meshkit/generators/github"
"github.com/meshery/meshkit/generators/models"
"github.com/stretchr/testify/assert"
)

type stubPackageManager struct {
pkg models.Package
callCount *atomic.Int32
delay time.Duration
}

func (spm stubPackageManager) GetPackage() (models.Package, error) {
spm.callCount.Add(1)
if spm.delay > 0 {
time.Sleep(spm.delay)
}
return spm.pkg, nil
}

func stubPackageForRegistrant(registrant, url, packageName string) models.Package {
switch registrant {
case "artifacthub":
return artifacthubgen.AhPackage{
Name: fmt.Sprintf("%s:%s", registrant, packageName),
ChartUrl: url,
Version: "v1.0.0",
}
default:
return githubgen.GitHubPackage{
Name: packageName,
SourceURL: url,
}
}
}

func TestGenerationOptionsTimeoutBehavior(t *testing.T) {
// Test that timeout value is respected when set
tests := []struct {
Expand Down Expand Up @@ -44,6 +78,103 @@ func TestGenerationOptionsTimeoutBehavior(t *testing.T) {
}
}

func TestPackageFetcherCachesGitHubPackagesByRegistrantAndSourceURL(t *testing.T) {
t.Parallel()

callCount := &atomic.Int32{}
fetcher := newPackageFetcher(func(registrant, url, packageName string) (models.PackageManager, error) {
return stubPackageManager{
pkg: stubPackageForRegistrant(registrant, url, packageName),
callCount: callCount,
}, nil
})

firstPkg, err := fetcher.getPackage("github", "https://example.com/aso.yaml", "azure-network")
assert.NoError(t, err)

secondPkg, err := fetcher.getPackage("github", "https://example.com/aso.yaml", "azure-compute")
assert.NoError(t, err)

assert.EqualValues(t, 1, callCount.Load())
assert.Equal(t, "azure-network", firstPkg.GetName())
assert.Equal(t, "azure-compute", secondPkg.GetName())
}

func TestPackageFetcherDoesNotShareArtifactHubPackagesAcrossModelNames(t *testing.T) {
t.Parallel()

callCount := &atomic.Int32{}
fetcher := newPackageFetcher(func(registrant, url, packageName string) (models.PackageManager, error) {
return stubPackageManager{
pkg: stubPackageForRegistrant(registrant, url, packageName),
callCount: callCount,
}, nil
})

_, err := fetcher.getPackage("artifacthub", "https://example.com/shared.yaml", "azure-network")
assert.NoError(t, err)

_, err = fetcher.getPackage("artifacthub", "https://example.com/shared.yaml", "azure-compute")
assert.NoError(t, err)

assert.EqualValues(t, 2, callCount.Load())
}

func TestPackageFetcherDoesNotShareAcrossRegistrants(t *testing.T) {
t.Parallel()

callCount := &atomic.Int32{}
fetcher := newPackageFetcher(func(registrant, url, packageName string) (models.PackageManager, error) {
return stubPackageManager{
pkg: stubPackageForRegistrant(registrant, url, packageName),
callCount: callCount,
}, nil
})

_, err := fetcher.getPackage("github", "https://example.com/shared.yaml", "azure-network")
assert.NoError(t, err)

_, err = fetcher.getPackage("artifacthub", "https://example.com/shared.yaml", "azure-network")
assert.NoError(t, err)

assert.EqualValues(t, 2, callCount.Load())
}

func TestPackageFetcherDeduplicatesConcurrentGitHubRequests(t *testing.T) {
t.Parallel()

callCount := &atomic.Int32{}
fetcher := newPackageFetcher(func(registrant, url, packageName string) (models.PackageManager, error) {
return stubPackageManager{
pkg: stubPackageForRegistrant(registrant, url, packageName),
callCount: callCount,
delay: 25 * time.Millisecond,
}, nil
})

modelNames := []string{
"azure-network",
"azure-compute",
"azure-storage",
"azure-network",
"azure-compute",
"azure-storage",
}
var wg sync.WaitGroup
for _, modelName := range modelNames {
wg.Add(1)
go func(modelName string) {
defer wg.Done()
pkg, err := fetcher.getPackage("github", "https://example.com/aso.yaml", modelName)
assert.NoError(t, err)
assert.Equal(t, modelName, pkg.GetName())
}(modelName)
}
wg.Wait()

assert.EqualValues(t, 1, callCount.Load())
}

func TestProgressTrackerIntegration(t *testing.T) {
// Simulate a model generation workflow
totalModels := 50
Expand Down