diff --git a/internal/workflow_registry.go b/internal/workflow_registry.go new file mode 100644 index 0000000..3b5c67c --- /dev/null +++ b/internal/workflow_registry.go @@ -0,0 +1,100 @@ +package internal + +import ( + "context" + "fmt" + "sync" + "time" + + "homelab.lan/music-agregator/internal/eventbus" +) + +type WorkflowEntry struct { + ID string + AlbumID string + Quality string + Ctx context.Context + Cancel context.CancelFunc + Topic string + Ready chan struct{} +} + +type WorkflowRegistry struct { + mu sync.Mutex + workflows map[string]*WorkflowEntry + bus *eventbus.EventBus + wg sync.WaitGroup +} + +func NewWorkflowRegistry(bus *eventbus.EventBus) *WorkflowRegistry { + return &WorkflowRegistry{ + workflows: make(map[string]*WorkflowEntry), + bus: bus, + } +} + +func workflowKey(albumID, quality string) string { + return fmt.Sprintf("%s:%s", albumID, quality) +} + +func (r *WorkflowRegistry) GetOrCreate(ctx context.Context, albumID, quality string) (*WorkflowEntry, bool) { + r.mu.Lock() + defer r.mu.Unlock() + + key := workflowKey(albumID, quality) + if entry, ok := r.workflows[key]; ok { + return entry, false + } + + wfCtx, cancel := context.WithCancel(ctx) + entry := &WorkflowEntry{ + AlbumID: albumID, + Quality: quality, + Ctx: wfCtx, + Cancel: cancel, + Topic: key, + Ready: make(chan struct{}), + } + r.workflows[key] = entry + return entry, true +} + +func (r *WorkflowRegistry) Remove(albumID, quality string) { + r.mu.Lock() + defer r.mu.Unlock() + + key := workflowKey(albumID, quality) + delete(r.workflows, key) +} + +func (r *WorkflowRegistry) Get(albumID, quality string) (*WorkflowEntry, bool) { + r.mu.Lock() + defer r.mu.Unlock() + + key := workflowKey(albumID, quality) + entry, ok := r.workflows[key] + return entry, ok +} + +func (r *WorkflowRegistry) WaitGroup() *sync.WaitGroup { + return &r.wg +} + +func (r *WorkflowRegistry) Shutdown(timeout time.Duration) { + r.mu.Lock() + for _, entry := range r.workflows { + entry.Cancel() + } + r.mu.Unlock() + + done := make(chan struct{}) + go func() { + r.wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + } +} diff --git a/internal/workflow_registry_test.go b/internal/workflow_registry_test.go new file mode 100644 index 0000000..fcebf07 --- /dev/null +++ b/internal/workflow_registry_test.go @@ -0,0 +1,144 @@ +package internal + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "homelab.lan/music-agregator/internal/eventbus" +) + +func TestRegistry_GetOrCreate_New(t *testing.T) { + bus := eventbus.New() + reg := NewWorkflowRegistry(bus) + + entry, created := reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + assert.True(t, created) + require.NotNil(t, entry) + assert.Equal(t, "album-1", entry.AlbumID) + assert.Equal(t, "LOSSLESS", entry.Quality) + assert.Equal(t, "album-1:LOSSLESS", entry.Topic) + assert.NotNil(t, entry.Ctx) + assert.NotNil(t, entry.Cancel) + assert.NotNil(t, entry.Ready) +} + +func TestRegistry_GetOrCreate_ExistingReturned(t *testing.T) { + bus := eventbus.New() + reg := NewWorkflowRegistry(bus) + + entry1, created1 := reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + assert.True(t, created1) + + entry2, created2 := reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + assert.False(t, created2) + assert.Same(t, entry1, entry2) +} + +func TestRegistry_GetOrCreate_DifferentQuality(t *testing.T) { + bus := eventbus.New() + reg := NewWorkflowRegistry(bus) + + entry1, created1 := reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + assert.True(t, created1) + + entry2, created2 := reg.GetOrCreate(context.Background(), "album-1", "LOSSY") + assert.True(t, created2) + assert.NotSame(t, entry1, entry2) +} + +func TestRegistry_Remove(t *testing.T) { + bus := eventbus.New() + reg := NewWorkflowRegistry(bus) + + reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + reg.Remove("album-1", "LOSSLESS") + + _, ok := reg.Get("album-1", "LOSSLESS") + assert.False(t, ok) + + entry, created := reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + assert.True(t, created) + assert.NotNil(t, entry) +} + +func TestRegistry_Get(t *testing.T) { + bus := eventbus.New() + reg := NewWorkflowRegistry(bus) + + _, ok := reg.Get("album-1", "LOSSLESS") + assert.False(t, ok) + + reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + entry, ok := reg.Get("album-1", "LOSSLESS") + assert.True(t, ok) + assert.Equal(t, "album-1", entry.AlbumID) +} + +func TestRegistry_ConcurrentGetOrCreate(t *testing.T) { + bus := eventbus.New() + reg := NewWorkflowRegistry(bus) + + var wg sync.WaitGroup + results := make(chan bool, 20) + + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, created := reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + results <- created + }() + } + + wg.Wait() + close(results) + + createdCount := 0 + for created := range results { + if created { + createdCount++ + } + } + assert.Equal(t, 1, createdCount) +} + +func TestRegistry_Shutdown(t *testing.T) { + bus := eventbus.New() + reg := NewWorkflowRegistry(bus) + + entry, _ := reg.GetOrCreate(context.Background(), "album-1", "LOSSLESS") + + reg.WaitGroup().Add(1) + go func() { + defer reg.WaitGroup().Done() + <-entry.Ctx.Done() + }() + + start := time.Now() + reg.Shutdown(5 * time.Second) + elapsed := time.Since(start) + + assert.Less(t, elapsed, 2*time.Second) + assert.Error(t, entry.Ctx.Err()) +} + +func TestRegistry_ShutdownTimeout(t *testing.T) { + bus := eventbus.New() + reg := NewWorkflowRegistry(bus) + + reg.WaitGroup().Add(1) + + start := time.Now() + reg.Shutdown(100 * time.Millisecond) + elapsed := time.Since(start) + + assert.GreaterOrEqual(t, elapsed, 100*time.Millisecond) + assert.Less(t, elapsed, 500*time.Millisecond) + + reg.WaitGroup().Done() +}