diff --git a/TODO.md b/TODO.md index 4c2d16b..f540068 100644 --- a/TODO.md +++ b/TODO.md @@ -23,7 +23,7 @@ - [ENHANCEMENT] Change how frequently the course availability check is run - [ENHANCEMENT] Support for FFMPEG path - [ENHANCEMENT] On mobile use a drawer for tags -- [ENHANCEMENT] Write a general scanning monitor util +- [ENHANCEMENT] Write a general course scanner - Add 1 or more scans, do a bulk query for all in the list - take a writable and update the status @@ -33,7 +33,7 @@ - [ENHANCEMENT] Support adding categories from on the home page - [ENHANCEMENT] Fix the difference in location of the loading icon and the error - [ENHANCEMENT] Change from carousel to no carousel -- +- [ENHANCEMENT] Add completed and updated icons on course cards ### Courses @@ -113,7 +113,6 @@ ### Tags -- [ENHANCEMENT] Currently uppercase and lowercase tags are different and so uppercase are ordered first. Make them case insensitive - [ENHANCEMENT] Analyze and optimize the DB ### Assets and Attachments diff --git a/api/assets.go b/api/assets.go index 0f9a3ec..aef4fa4 100644 --- a/api/assets.go +++ b/api/assets.go @@ -8,7 +8,6 @@ import ( "path/filepath" "strconv" "strings" - "time" "github.com/geerew/off-course/daos" "github.com/geerew/off-course/database" @@ -34,20 +33,20 @@ type assets struct { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ type assetResponse struct { - ID string `json:"id"` - CourseID string `json:"courseId"` - Title string `json:"title"` - Prefix int `json:"prefix"` - Chapter string `json:"chapter"` - Path string `json:"path"` - Type types.Asset `json:"assetType"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` + ID string `json:"id"` + CourseID string `json:"courseId"` + Title string `json:"title"` + Prefix int `json:"prefix"` + Chapter string `json:"chapter"` + Path string `json:"path"` + Type types.Asset `json:"assetType"` + CreatedAt types.DateTime `json:"createdAt"` + UpdatedAt types.DateTime `json:"updatedAt"` // Progress - VideoPos int `json:"videoPos"` - Completed bool `json:"completed"` - CompletedAt time.Time `json:"completedAt"` + VideoPos int `json:"videoPos"` + Completed bool `json:"completed"` + CompletedAt types.DateTime `json:"completedAt"` // Attachments Attachments []*attachmentResponse `json:"attachments,omitempty"` diff --git a/api/assets_test.go b/api/assets_test.go index d55567f..6ec3f4b 100644 --- a/api/assets_test.go +++ b/api/assets_test.go @@ -394,7 +394,7 @@ func TestAssets_ServeAsset(t *testing.T) { coursesDao := daos.NewCourseDao(router.config.DbManager.DataDb) assetsDao := daos.NewAssetDao(router.config.DbManager.DataDb) - require.Nil(t, coursesDao.Create(testData[0].Course)) + require.Nil(t, coursesDao.Create(testData[0].Course, nil)) require.Nil(t, assetsDao.Create(testData[0].Assets[0], nil)) // Create asset diff --git a/api/attachments.go b/api/attachments.go index 8a3b9ec..115339c 100644 --- a/api/attachments.go +++ b/api/attachments.go @@ -4,13 +4,13 @@ import ( "database/sql" "log/slog" "strings" - "time" "github.com/geerew/off-course/daos" "github.com/geerew/off-course/database" "github.com/geerew/off-course/models" "github.com/geerew/off-course/utils/appFs" "github.com/geerew/off-course/utils/pagination" + "github.com/geerew/off-course/utils/types" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/filesystem" "github.com/spf13/afero" @@ -27,13 +27,13 @@ type attachments struct { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ type attachmentResponse struct { - ID string `json:"id"` - AssetId string `json:"assetId"` - CourseID string `json:"courseId"` - Title string `json:"title"` - Path string `json:"path"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` + ID string `json:"id"` + AssetId string `json:"assetId"` + CourseID string `json:"courseId"` + Title string `json:"title"` + Path string `json:"path"` + CreatedAt types.DateTime `json:"createdAt"` + UpdatedAt types.DateTime `json:"updatedAt"` } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/api/courses.go b/api/courses.go index b31cdd1..2b6df4f 100644 --- a/api/courses.go +++ b/api/courses.go @@ -7,7 +7,6 @@ import ( "net/url" "os" "strings" - "time" "github.com/Masterminds/squirrel" "github.com/geerew/off-course/daos" @@ -17,6 +16,7 @@ import ( "github.com/geerew/off-course/utils/appFs" "github.com/geerew/off-course/utils/jobs" "github.com/geerew/off-course/utils/pagination" + "github.com/geerew/off-course/utils/types" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/filesystem" "github.com/spf13/afero" @@ -41,23 +41,23 @@ type courses struct { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ type courseResponse struct { - ID string `json:"id"` - Title string `json:"title"` - Path string `json:"path"` - HasCard bool `json:"hasCard"` - Available bool `json:"available"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` + ID string `json:"id"` + Title string `json:"title"` + Path string `json:"path"` + HasCard bool `json:"hasCard"` + Available bool `json:"available"` + CreatedAt types.DateTime `json:"createdAt"` + UpdatedAt types.DateTime `json:"updatedAt"` // Scan status ScanStatus string `json:"scanStatus"` // Progress - Started bool `json:"started"` - StartedAt time.Time `json:"startedAt"` - Percent int `json:"percent"` - CompletedAt time.Time `json:"completedAt"` - ProgressUpdatedAt time.Time `json:"progressUpdatedAt"` + Started bool `json:"started"` + StartedAt types.DateTime `json:"startedAt"` + Percent int `json:"percent"` + CompletedAt types.DateTime `json:"completedAt"` + ProgressUpdatedAt types.DateTime `json:"progressUpdatedAt"` } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -116,7 +116,7 @@ func (api *courses) getCourses(c *fiber.Ctx) error { return errorResponse(c, fiber.StatusBadRequest, "Invalid tags parameter", err) } - courseIds, err := api.courseTagDao.ListCourseIdsByTags(filtered, nil) + courseIds, err := api.courseTagDao.ListCourseIdsByTags(filtered, nil, nil) if err != nil { return errorResponse(c, fiber.StatusInternalServerError, "Error looking up courses by tags", err) } @@ -169,7 +169,7 @@ func (api *courses) getCourses(c *fiber.Ctx) error { func (api *courses) getCourse(c *fiber.Ctx) error { id := c.Params("id") - course, err := api.courseDao.Get(id, nil, nil) + course, err := api.courseDao.Get(id, nil) if err != nil { if err == sql.ErrNoRows { @@ -209,7 +209,7 @@ func (api *courses) createCourse(c *fiber.Ctx) error { // Set the course to available course.Available = true - if err := api.courseDao.Create(course); err != nil { + if err := api.courseDao.Create(course, nil); err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { return errorResponse(c, fiber.StatusBadRequest, "A course with this path already exists", err) } @@ -245,7 +245,7 @@ func (api *courses) deleteCourse(c *fiber.Ctx) error { func (api *courses) getCard(c *fiber.Ctx) error { id := c.Params("id") - course, err := api.courseDao.Get(id, nil, nil) + course, err := api.courseDao.Get(id, nil) if err != nil { if err == sql.ErrNoRows { @@ -277,7 +277,7 @@ func (api *courses) getAssets(c *fiber.Ctx) error { expand := c.QueryBool("expand", false) // Get the course - _, err := api.courseDao.Get(id, nil, nil) + _, err := api.courseDao.Get(id, nil) if err != nil { if err == sql.ErrNoRows { return errorResponse(c, fiber.StatusNotFound, "Course not found", nil) @@ -316,7 +316,7 @@ func (api *courses) getAsset(c *fiber.Ctx) error { assetId := c.Params("asset") expand := c.QueryBool("expand", false) - _, err := api.courseDao.Get(id, nil, nil) + _, err := api.courseDao.Get(id, nil) if err != nil { if err == sql.ErrNoRows { return errorResponse(c, fiber.StatusNotFound, "Course not found", nil) @@ -356,7 +356,7 @@ func (api *courses) getAssetAttachments(c *fiber.Ctx) error { orderBy := c.Query("orderBy", "title asc") // Get the course - _, err := api.courseDao.Get(id, nil, nil) + _, err := api.courseDao.Get(id, nil) if err != nil { if err == sql.ErrNoRows { return errorResponse(c, fiber.StatusNotFound, "Course not found", nil) @@ -406,7 +406,7 @@ func (api *courses) getAssetAttachment(c *fiber.Ctx) error { attachmentId := c.Params("attachment") // Get the course - _, err := api.courseDao.Get(id, nil, nil) + _, err := api.courseDao.Get(id, nil) if err != nil { if err == sql.ErrNoRows { return errorResponse(c, fiber.StatusNotFound, "Course not found", nil) @@ -452,7 +452,7 @@ func (api *courses) getTags(c *fiber.Ctx) error { id := c.Params("id") // Get the course - _, err := api.courseDao.Get(id, nil, nil) + _, err := api.courseDao.Get(id, nil) if err != nil { if err == sql.ErrNoRows { return errorResponse(c, fiber.StatusNotFound, "Course not found", nil) diff --git a/api/courses_test.go b/api/courses_test.go index 4242002..b53bb8f 100644 --- a/api/courses_test.go +++ b/api/courses_test.go @@ -642,7 +642,7 @@ func TestCourses_DeleteCourse(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusNoContent, status) - _, err = courseDao.Get(testData[2].ID, nil, nil) + _, err = courseDao.Get(testData[2].ID, nil) require.ErrorIs(t, err, sql.ErrNoRows) // ---------------------------- diff --git a/api/logs.go b/api/logs.go index 7230032..271a424 100644 --- a/api/logs.go +++ b/api/logs.go @@ -2,7 +2,6 @@ package api import ( "log/slog" - "time" "github.com/Masterminds/squirrel" "github.com/geerew/off-course/daos" @@ -23,11 +22,11 @@ type logs struct { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ type logResponse struct { - ID string `json:"id"` - Level int `json:"level"` - Message string `json:"message"` - Data types.JsonMap `json:"data"` - CreatedAt time.Time `json:"createdAt"` + ID string `json:"id"` + Level int `json:"level"` + Message string `json:"message"` + Data types.JsonMap `json:"data"` + CreatedAt types.DateTime `json:"createdAt"` } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/api/scans.go b/api/scans.go index dd1ac2a..5f1abf1 100644 --- a/api/scans.go +++ b/api/scans.go @@ -3,7 +3,6 @@ package api import ( "database/sql" "log/slog" - "time" "github.com/geerew/off-course/daos" "github.com/geerew/off-course/models" @@ -28,8 +27,8 @@ type scanResponse struct { ID string `json:"id"` CourseID string `json:"courseId"` Status types.ScanStatus `json:"status"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` + CreatedAt types.DateTime `json:"createdAt"` + UpdatedAt types.DateTime `json:"updatedAt"` } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/api/tags.go b/api/tags.go index a0c1312..6403726 100644 --- a/api/tags.go +++ b/api/tags.go @@ -6,13 +6,13 @@ import ( "net/url" "sort" "strings" - "time" "github.com/Masterminds/squirrel" "github.com/geerew/off-course/daos" "github.com/geerew/off-course/database" "github.com/geerew/off-course/models" "github.com/geerew/off-course/utils/pagination" + "github.com/geerew/off-course/utils/types" "github.com/gofiber/fiber/v2" ) @@ -27,12 +27,12 @@ type tags struct { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ type tagResponse struct { - ID string `json:"id"` - Tag string `json:"tag"` - CourseCount int `json:"courseCount"` - Courses []*courseTag `json:"courses,omitempty"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` + ID string `json:"id"` + Tag string `json:"tag"` + CourseCount int `json:"courseCount"` + Courses []*courseTag `json:"courses,omitempty"` + CreatedAt types.DateTime `json:"createdAt"` + UpdatedAt types.DateTime `json:"updatedAt"` } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/cron/base.go b/cron/base.go index 490dc16..c579b3b 100644 --- a/cron/base.go +++ b/cron/base.go @@ -3,6 +3,7 @@ package cron import ( "log/slog" + "github.com/geerew/off-course/dao" "github.com/geerew/off-course/database" "github.com/geerew/off-course/utils/appFs" "github.com/geerew/off-course/utils/types" @@ -31,6 +32,7 @@ func InitCron(config *CronConfig) { // Course availability ca := &courseAvailability{ db: config.Db, + dao: dao.NewDAO(config.Db), appFs: config.AppFs, logger: config.Logger, batchSize: 200, diff --git a/cron/common_test.go b/cron/common_test.go index 581a5ed..d949810 100644 --- a/cron/common_test.go +++ b/cron/common_test.go @@ -14,7 +14,7 @@ import ( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func setup(t *testing.T) (*database.DatabaseManager, *appFs.AppFs, *slog.Logger, *[]*logger.Log) { +func setup(t *testing.T) (database.Database, *appFs.AppFs, *slog.Logger, *[]*logger.Log) { t.Helper() // Logger @@ -37,9 +37,9 @@ func setup(t *testing.T) (*database.DatabaseManager, *appFs.AppFs, *slog.Logger, InMemory: true, }) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, dbManager) // teardown - return dbManager, appFs, logger, &logs + return dbManager.DataDb, appFs, logger, &logs } diff --git a/cron/course_availability.go b/cron/course_availability.go index 12a35b1..e6aa629 100644 --- a/cron/course_availability.go +++ b/cron/course_availability.go @@ -1,10 +1,11 @@ package cron import ( + "context" "log/slog" "os" - "github.com/geerew/off-course/daos" + "github.com/geerew/off-course/dao" "github.com/geerew/off-course/database" "github.com/geerew/off-course/models" "github.com/geerew/off-course/utils/appFs" @@ -13,6 +14,7 @@ import ( type courseAvailability struct { db database.Database + dao *dao.DAO appFs *appFs.AppFs logger *slog.Logger batchSize int @@ -28,14 +30,15 @@ func (ca *courseAvailability) run() error { coursesBatch := make([]*models.Course, 0, ca.batchSize) - courseDao := daos.NewCourseDao(ca.db) + ctx := context.Background() for page <= totalPages { p := pagination.New(page, perPage) - paginationParams := &database.DatabaseParams{Pagination: p} + options := &database.Options{Pagination: p} // Fetch a batch of courses - courses, err := courseDao.List(paginationParams, nil) + courses := []*models.Course{} + err := ca.dao.List(ctx, &courses, options) if err != nil { attrs := []any{ loggerType, @@ -82,7 +85,7 @@ func (ca *courseAvailability) run() error { // Update the courses if we hit the batch size if len(coursesBatch) == ca.batchSize { - ca.writeAll(coursesBatch) + ca.writeAll(ctx, coursesBatch) coursesBatch = coursesBatch[:0] } } @@ -92,7 +95,7 @@ func (ca *courseAvailability) run() error { // Update any remaining courses if len(coursesBatch) > 0 { - ca.writeAll(coursesBatch) + ca.writeAll(ctx, coursesBatch) } return nil @@ -100,13 +103,11 @@ func (ca *courseAvailability) run() error { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func (ca *courseAvailability) writeAll(courses []*models.Course) { - courseDao := daos.NewCourseDao(ca.db) - +func (ca *courseAvailability) writeAll(ctx context.Context, courses []*models.Course) { // Update the courses in a transaction - err := ca.db.RunInTransaction(func(tx *database.Tx) error { + err := ca.db.RunInTransaction(ctx, func(txCtx context.Context) error { for _, course := range courses { - err := courseDao.Update(course, tx) + err := ca.dao.UpdateCourse(txCtx, course) if err != nil { return err } diff --git a/cron/course_availability_test.go b/cron/course_availability_test.go index c9143a3..ea1f6c8 100644 --- a/cron/course_availability_test.go +++ b/cron/course_availability_test.go @@ -1,10 +1,12 @@ package cron import ( + "context" "fmt" "testing" - "github.com/geerew/off-course/daos" + "github.com/geerew/off-course/dao" + "github.com/geerew/off-course/models" "github.com/geerew/off-course/utils/appFs" "github.com/geerew/off-course/utils/mocks" "github.com/spf13/afero" @@ -14,117 +16,62 @@ import ( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ func TestCourseAvailability_Run(t *testing.T) { - t.Run("single update", func(t *testing.T) { - dbManager, appFs, logger, _ := setup(t) + t.Run("update", func(t *testing.T) { + db, appFs, logger, _ := setup(t) - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Build() + dao := dao.NewDAO(db) + ctx := context.Background() - // ---------------------------- - // true -> false - // ---------------------------- - - // Mark the course as available - testData[0].Course.Available = true - require.Nil(t, daos.NewCourseDao(dbManager.DataDb).Update(testData[0].Course, nil)) - - ca := &courseAvailability{ - db: dbManager.DataDb, - appFs: appFs, - logger: logger, - batchSize: 1, - } - - err := ca.run() - require.Nil(t, err) - - // Check the course is marked as unavailable - course, err := daos.NewCourseDao(dbManager.DataDb).Get(testData[0].Course.ID, nil, nil) - require.Nil(t, err) - require.False(t, course.Available) - - // ---------------------------- - // false -> true - // ---------------------------- - - // Create course directory - require.Nil(t, appFs.Fs.MkdirAll(course.Path, 0755)) - - err = ca.run() - require.Nil(t, err) - - // Check the course is marked as available - course, err = daos.NewCourseDao(dbManager.DataDb).Get(testData[0].Course.ID, nil, nil) - require.Nil(t, err) - require.True(t, course.Available) - }) - - t.Run("multi update", func(t *testing.T) { - dbManager, appFs, logger, _ := setup(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(3).Build() - - // ---------------------------- - // true -> false - // ---------------------------- - - // Mark the courses as available - for _, data := range testData { - data.Course.Available = true - require.Nil(t, daos.NewCourseDao(dbManager.DataDb).Update(data.Course, nil)) + courses := []*models.Course{} + for i := range 3 { + course := &models.Course{Title: fmt.Sprintf("course %d", i), Path: fmt.Sprintf("/course-%d", i), Available: false} + require.NoError(t, dao.CreateCourse(ctx, course)) + courses = append(courses, course) } ca := &courseAvailability{ - db: dbManager.DataDb, + db: db, + dao: dao, appFs: appFs, logger: logger, batchSize: 2, } err := ca.run() - require.Nil(t, err) + require.NoError(t, err) - // Check the courses are marked as unavailable - for _, data := range testData { - course, err := daos.NewCourseDao(dbManager.DataDb).Get(data.Course.ID, nil, nil) - require.Nil(t, err) - require.False(t, course.Available) - } - - // ---------------------------- - // false -> true - // ---------------------------- - - // Create course directories - for _, data := range testData { - require.Nil(t, appFs.Fs.MkdirAll(data.Course.Path, 0755)) + for _, course := range courses { + require.Nil(t, appFs.Fs.MkdirAll(course.Path, 0755)) } err = ca.run() - require.Nil(t, err) + require.NoError(t, err) - // Check the courses are marked as available - for _, data := range testData { - course, err := daos.NewCourseDao(dbManager.DataDb).Get(data.Course.ID, nil, nil) - require.Nil(t, err) + for _, course := range courses { + err := dao.GetById(ctx, course) + require.NoError(t, err) require.True(t, course.Available) } }) t.Run("stat error", func(t *testing.T) { - dbManager, _, logger, logs := setup(t) + db, _, logger, logs := setup(t) + + dao := dao.NewDAO(db) + ctx := context.Background() - daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Build() + course := &models.Course{Title: "course 1", Path: "/course-1", Available: false} + require.NoError(t, dao.CreateCourse(ctx, course)) fsWithError := &mocks.MockFsWithError{ Fs: afero.NewMemMapFs(), ErrToReturn: fmt.Errorf("stat error"), } - caAppFs := appFs.NewAppFs(fsWithError, logger) - ca := &courseAvailability{ - db: dbManager.DataDb, - appFs: caAppFs, + db: db, + dao: dao, + appFs: appFs.NewAppFs(fsWithError, logger), logger: logger, batchSize: 1, } @@ -135,24 +82,24 @@ func TestCourseAvailability_Run(t *testing.T) { // Check the logger require.Len(t, *logs, 2) require.Equal(t, "Failed to stat course", (*logs)[1].Message) - }) + t.Run("db error", func(t *testing.T) { - dbManager, appFs, logger, logs := setup(t) + db, appFs, logger, logs := setup(t) - // Drop the table - _, err := dbManager.DataDb.Exec("DROP TABLE IF EXISTS " + daos.NewCourseDao(dbManager.DataDb).Table()) - require.Nil(t, err) + _, err := db.Exec("DROP TABLE IF EXISTS " + models.COURSE_TABLE) + require.NoError(t, err) ca := &courseAvailability{ - db: dbManager.DataDb, + db: db, + dao: dao.NewDAO(db), appFs: appFs, logger: logger, batchSize: 1, } err = ca.run() - require.ErrorContains(t, err, "no such table: "+daos.NewCourseDao(dbManager.DataDb).Table()) + require.ErrorContains(t, err, "no such table: "+models.COURSE_TABLE) // Check the logger require.Len(t, *logs, 2) diff --git a/dao/asset.go b/dao/asset.go new file mode 100644 index 0000000..2228503 --- /dev/null +++ b/dao/asset.go @@ -0,0 +1,38 @@ +package dao + +import ( + "context" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateAsset creates an asset and refreshes course progress +func (dao *DAO) CreateAsset(ctx context.Context, asset *models.Asset) error { + if asset == nil { + return utils.ErrNilPtr + } + + return dao.db.RunInTransaction(ctx, func(txCtx context.Context) error { + err := dao.Create(txCtx, asset) + if err != nil { + return err + } + + return dao.RefreshCourseProgress(txCtx, asset.CourseID) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UpdateAsset updates an asset +func (dao *DAO) UpdateAsset(ctx context.Context, asset *models.Asset) error { + if asset == nil { + return utils.ErrNilPtr + } + + _, err := dao.Update(ctx, asset) + return err +} diff --git a/dao/asset_progress.go b/dao/asset_progress.go new file mode 100644 index 0000000..c1cd64b --- /dev/null +++ b/dao/asset_progress.go @@ -0,0 +1,107 @@ +package dao + +import ( + "context" + "database/sql" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateOrUpdateAssetProgress creates/updates an asset progress and refreshes course progress +func (dao *DAO) CreateOrUpdateAssetProgress(ctx context.Context, assetProgress *models.AssetProgress) error { + if assetProgress == nil { + return utils.ErrNilPtr + } + + return dao.db.RunInTransaction(ctx, func(txCtx context.Context) error { + if assetProgress.VideoPos < 0 { + assetProgress.VideoPos = 0 + } + + // Check for existing asset progress + existingAP := &models.AssetProgress{} + err := dao.Get( + txCtx, + existingAP, + &database.Options{Where: squirrel.Eq{models.ASSET_PROGRESS_TABLE + ".asset_id": assetProgress.AssetID}}, + ) + + if err != nil && err != sql.ErrNoRows { + return err + } + + // Create + if err == sql.ErrNoRows { + if assetProgress.Completed { + assetProgress.CompletedAt = types.NowDateTime() + } + + err := dao.Create(txCtx, assetProgress) + if err != nil { + return err + } + } else { + + // Update + if assetProgress.Completed { + if existingAP.Completed { + assetProgress.CompletedAt = existingAP.CompletedAt + } else { + assetProgress.CompletedAt = types.NowDateTime() + } + } else { + assetProgress.CompletedAt = types.DateTime{} + } + + _, err = dao.Update(txCtx, assetProgress) + if err != nil { + return err + } + } + + // Refresh course progress + return dao.RefreshCourseProgress(txCtx, assetProgress.CourseID) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Update updates select columns of an asset progress, then refresh course progress +// +// - `video_pos` (for video assets) +// - `completed` (for all assets) +// +// When `completed` is true, `completed_at` is set to the current time, else it will be null +// +// A new transaction is created if `tx` is nil +// func (dao *AssetProgressDao) Update(ap *models.AssetProgress, tx *database.Tx) error { + +// + +// // Update (or create if it doesn't exist) +// query, args, _ := squirrel. +// StatementBuilder. +// Insert(dao.Table()). +// SetMap(modelToMapOrPanic(ap)). +// Suffix( +// "ON CONFLICT (asset_id) DO UPDATE SET video_pos = ?, completed = ?, completed_at = ?, updated_at = ?", +// ap.VideoPos, ap.Completed, ap.CompletedAt, ap.UpdatedAt, +// ). +// ToSql() + +// _, err = tx.Exec(query, args...) +// if err != nil { +// return err +// } + +// // Refresh course progress +// cpDao := NewCourseProgressDao(dao.db) +// return cpDao.Refresh(ap.CourseID, tx) +// } +// } diff --git a/dao/asset_progress_test.go b/dao/asset_progress_test.go new file mode 100644 index 0000000..b268cbc --- /dev/null +++ b/dao/asset_progress_test.go @@ -0,0 +1,76 @@ +package dao + +import ( + "database/sql" + "testing" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CreateOrUpdateAssetProgress(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset)) + + assetProgress := &models.AssetProgress{ + AssetID: asset.ID, + CourseID: course.ID, + } + require.NoError(t, dao.CreateOrUpdateAssetProgress(ctx, assetProgress)) + }) + + t.Run("nil", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.CreateOrUpdateAssetProgress(ctx, nil), utils.ErrNilPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_AssetProgressDeleteCascade(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course", Path: "/course"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset)) + + assetProgress := &models.AssetProgress{ + AssetID: asset.ID, + CourseID: course.ID, + } + require.NoError(t, dao.CreateOrUpdateAssetProgress(ctx, assetProgress)) + + require.NoError(t, dao.Delete(ctx, asset, nil)) + + err := dao.Get(ctx, assetProgress, &database.Options{Where: squirrel.Eq{assetProgress.Table() + ".id": assetProgress.ID}}) + require.ErrorIs(t, err, sql.ErrNoRows) +} diff --git a/dao/asset_test.go b/dao/asset_test.go new file mode 100644 index 0000000..b317e43 --- /dev/null +++ b/dao/asset_test.go @@ -0,0 +1,136 @@ +package dao + +import ( + "database/sql" + "testing" + "time" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CreateAsset(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset)) + }) + + t.Run("nil", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.CreateAsset(ctx, nil), utils.ErrNilPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_UpdateAsset(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + originalAsset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, originalAsset)) + + time.Sleep(1 * time.Millisecond) + + newAsset := &models.Asset{ + Base: originalAsset.Base, + Title: "Asset 2", // Mutable + Prefix: sql.NullInt16{Int16: 2, Valid: true}, // Mutable + Chapter: "Chapter 2", // Mutable + Type: *types.NewAsset("html"), // Mutable + Path: "/course-1/02 asset.html", // Mutable + Hash: "5678", // Mutable + } + require.NoError(t, dao.UpdateAsset(ctx, newAsset)) + + assertResult := &models.Asset{Base: models.Base{ID: originalAsset.ID}} + require.NoError(t, dao.GetById(ctx, assertResult)) + require.Equal(t, newAsset.ID, assertResult.ID) // No change + require.True(t, newAsset.CreatedAt.Equal(originalAsset.CreatedAt)) // No change + require.Equal(t, newAsset.Title, assertResult.Title) // Changed + require.Equal(t, newAsset.Prefix, assertResult.Prefix) // Changed + require.Equal(t, newAsset.Chapter, assertResult.Chapter) // Changed + require.Equal(t, newAsset.Type, assertResult.Type) // Changed + require.Equal(t, newAsset.Path, assertResult.Path) // Changed + require.Equal(t, newAsset.Hash, assertResult.Hash) // Changed + require.False(t, assertResult.UpdatedAt.Equal(originalAsset.UpdatedAt)) // Changed + }) + + t.Run("invalid", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset)) + + // Empty ID + asset.ID = "" + require.ErrorIs(t, dao.UpdateAsset(ctx, asset), utils.ErrInvalidId) + + // Nil Model + require.ErrorIs(t, dao.UpdateAsset(ctx, nil), utils.ErrNilPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_AssetDeleteCascade(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course", Path: "/course"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset)) + + require.Nil(t, dao.Delete(ctx, course, nil)) + + err := dao.GetById(ctx, asset) + require.ErrorIs(t, err, sql.ErrNoRows) +} diff --git a/dao/attachment.go b/dao/attachment.go new file mode 100644 index 0000000..f706a95 --- /dev/null +++ b/dao/attachment.go @@ -0,0 +1,31 @@ +package dao + +import ( + "context" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateAttachment creates an attachment +func (dao *DAO) CreateAttachment(ctx context.Context, attachment *models.Attachment) error { + if attachment == nil { + return utils.ErrNilPtr + } + + return dao.Create(ctx, attachment) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UpdateAttachment updates an attachment +func (dao *DAO) UpdateAttachment(ctx context.Context, attachment *models.Attachment) error { + if attachment == nil { + return utils.ErrNilPtr + } + + _, err := dao.Update(ctx, attachment) + return err +} diff --git a/dao/attachments_test.go b/dao/attachments_test.go new file mode 100644 index 0000000..579727c --- /dev/null +++ b/dao/attachments_test.go @@ -0,0 +1,293 @@ +package dao + +import ( + "database/sql" + "testing" + "time" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CreateAttachment(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + + require.NoError(t, dao.CreateAsset(ctx, asset)) + + attachment := &models.Attachment{ + AssetID: asset.ID, + CourseID: course.ID, + Title: "Attachment 1", + Path: "/course-1/01 attachment.txt", + } + require.NoError(t, dao.CreateAttachment(ctx, attachment)) + }) + + t.Run("nil", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.CreateAttachment(ctx, nil), utils.ErrNilPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_UpdateAttachment(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + + require.NoError(t, dao.CreateAsset(ctx, asset)) + + originalAttachment := &models.Attachment{ + AssetID: asset.ID, + CourseID: course.ID, + Title: "Attachment 1", + Path: "/course-1/01 Attachment 1.txt", + } + require.NoError(t, dao.CreateAttachment(ctx, originalAttachment)) + + time.Sleep(1 * time.Millisecond) + + newAttachment := &models.Attachment{ + Base: originalAttachment.Base, + AssetID: asset.ID, // Immutable + CourseID: course.ID, // Immutable + Title: "Attachment 2", // Mutable + Path: "/course-1/01 Attachment 2.txt", // Mutable + } + require.NoError(t, dao.UpdateAttachment(ctx, newAttachment)) + + attachmentResult := &models.Attachment{Base: models.Base{ID: originalAttachment.ID}} + require.NoError(t, dao.GetById(ctx, attachmentResult)) + require.Equal(t, newAttachment.ID, attachmentResult.ID) // No change + require.Equal(t, newAttachment.AssetID, attachmentResult.AssetID) // No change + require.Equal(t, newAttachment.CourseID, attachmentResult.CourseID) // No change + require.True(t, newAttachment.CreatedAt.Equal(originalAttachment.CreatedAt)) // No change + require.Equal(t, newAttachment.Title, attachmentResult.Title) // Changed + require.Equal(t, newAttachment.Path, attachmentResult.Path) // Changed + require.False(t, attachmentResult.UpdatedAt.Equal(originalAttachment.UpdatedAt)) // Changed + }) + + t.Run("invalid", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset)) + + attachment := &models.Attachment{ + AssetID: asset.ID, + CourseID: course.ID, + Title: "Attachment 1", + Path: "/course-1/01 attachment.txt", + } + require.NoError(t, dao.CreateAttachment(ctx, attachment)) + + // Empty ID + attachment.ID = "" + require.ErrorIs(t, dao.UpdateAttachment(ctx, attachment), utils.ErrInvalidId) + + // Nil Model + require.ErrorIs(t, dao.UpdateAttachment(ctx, nil), utils.ErrNilPtr) + }) +} + +// // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// func TestAttachment_List(t *testing.T) { +// t.Run("no entries", func(t *testing.T) { +// dao, _ := attachmentSetup(t) + +// assets, err := dao.List(nil, nil) +// require.Nil(t, err) +// require.Zero(t, assets) +// }) + +// t.Run("found", func(t *testing.T) { +// dao, db := attachmentSetup(t) + +// NewTestBuilder(t).Db(db).Courses(5).Assets(1).Attachments(1).Build() + +// result, err := dao.List(nil, nil) +// require.Nil(t, err) +// require.Len(t, result, 5) + +// }) + +// t.Run("orderby", func(t *testing.T) { +// dao, db := attachmentSetup(t) + +// testData := NewTestBuilder(t).Db(db).Courses(3).Assets(1).Attachments(1).Build() + +// // ---------------------------- +// // CREATED_AT DESC +// // ---------------------------- +// result, err := dao.List(&database.DatabaseParams{OrderBy: []string{"created_at desc"}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 3) +// require.Equal(t, testData[2].Assets[0].Attachments[0].ID, result[0].ID) +// require.Equal(t, testData[1].Assets[0].Attachments[0].ID, result[1].ID) +// require.Equal(t, testData[0].Assets[0].Attachments[0].ID, result[2].ID) + +// // ---------------------------- +// // CREATED_AT ASC +// // ---------------------------- +// result, err = dao.List(&database.DatabaseParams{OrderBy: []string{"created_at asc"}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 3) +// require.Equal(t, testData[0].Assets[0].Attachments[0].ID, result[0].ID) +// require.Equal(t, testData[1].Assets[0].Attachments[0].ID, result[1].ID) +// require.Equal(t, testData[2].Assets[0].Attachments[0].ID, result[2].ID) +// }) + +// t.Run("where", func(t *testing.T) { +// dao, db := attachmentSetup(t) + +// testData := NewTestBuilder(t).Db(db).Courses(3).Assets(2).Attachments(2).Build() + +// // ---------------------------- +// // EQUALS ID +// // ---------------------------- +// result, err := dao.List(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".id": testData[1].Assets[1].Attachments[0].ID}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 1) +// require.Equal(t, testData[1].Assets[1].Attachments[0].ID, result[0].ID) + +// // ---------------------------- +// // EQUALS ID OR ID +// // ---------------------------- +// dbParams := &database.DatabaseParams{ +// Where: squirrel.Or{ +// squirrel.Eq{dao.Table() + ".id": testData[1].Assets[1].Attachments[0].ID}, +// squirrel.Eq{dao.Table() + ".id": testData[2].Assets[0].Attachments[1].ID}, +// }, +// OrderBy: []string{"created_at asc"}, +// } + +// result, err = dao.List(dbParams, nil) +// require.Nil(t, err) +// require.Len(t, result, 2) +// require.Equal(t, testData[1].Assets[1].Attachments[0].ID, result[0].ID) +// require.Equal(t, testData[2].Assets[0].Attachments[1].ID, result[1].ID) + +// // ---------------------------- +// // ERROR +// // ---------------------------- +// result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) +// require.ErrorContains(t, err, "syntax error") +// require.Nil(t, result) +// }) + +// t.Run("pagination", func(t *testing.T) { +// dao, db := attachmentSetup(t) + +// testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Attachments(17).Build() + +// // ---------------------------- +// // Page 1 with 10 items +// // ---------------------------- +// p := pagination.New(1, 10) + +// result, err := dao.List(&database.DatabaseParams{Pagination: p}, nil) +// require.Nil(t, err) +// require.Len(t, result, 10) +// require.Equal(t, 17, p.TotalItems()) +// require.Equal(t, testData[0].Assets[0].Attachments[0].ID, result[0].ID) +// require.Equal(t, testData[0].Assets[0].Attachments[9].ID, result[9].ID) + +// // ---------------------------- +// // Page 2 with 7 items +// // ---------------------------- +// p = pagination.New(2, 10) + +// result, err = dao.List(&database.DatabaseParams{Pagination: p}, nil) +// require.Nil(t, err) +// require.Len(t, result, 7) +// require.Equal(t, 17, p.TotalItems()) +// require.Equal(t, testData[0].Assets[0].Attachments[10].ID, result[0].ID) +// require.Equal(t, testData[0].Assets[0].Attachments[16].ID, result[6].ID) +// }) + +// t.Run("db error", func(t *testing.T) { +// dao, db := attachmentSetup(t) + +// _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) +// require.Nil(t, err) + +// _, err = dao.List(nil, nil) +// require.ErrorContains(t, err, "no such table: "+dao.Table()) +// }) +// } + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_AttachmentDeleteCascade(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course", Path: "/course"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset)) + + attachment := &models.Attachment{ + AssetID: asset.ID, + CourseID: course.ID, + Title: "Attachment 1", + Path: "/course-1/01 attachment.txt", + } + require.NoError(t, dao.CreateAttachment(ctx, attachment)) + + require.Nil(t, dao.Delete(ctx, asset, nil)) + + err := dao.GetById(ctx, attachment) + require.ErrorIs(t, err, sql.ErrNoRows) +} diff --git a/dao/common.go b/dao/common.go new file mode 100644 index 0000000..2985843 --- /dev/null +++ b/dao/common.go @@ -0,0 +1,169 @@ +package dao + +import ( + "context" + "database/sql" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/schema" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// DAO is a data access object +type DAO struct { + db database.Database +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// NewDAO creates a new DAO +func NewDAO(db database.Database) *DAO { + return &DAO{db: db} +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Create is a generic function to create a model in the database +func (dao *DAO) Create(ctx context.Context, model models.Modeler) error { + sch, err := schema.Parse(model) + if err != nil { + return err + } + + if model.Id() == "" { + model.RefreshId() + } + + model.RefreshCreatedAt() + model.RefreshUpdatedAt() + + q := database.QuerierFromContext(ctx, dao.db) + _, err = sch.Insert(model, q) + return err +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Count is a generic function to count the number of rows in a table as determined by the model +func (dao *DAO) Count(ctx context.Context, model any, options *database.Options) (int, error) { + sch, err := schema.Parse(model) + if err != nil { + return 0, err + } + + q := database.QuerierFromContext(ctx, dao.db) + return sch.Count(options, q) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Get is a generic function to get a model (row) +func (dao *DAO) Get(ctx context.Context, model any, options *database.Options) error { + sch, err := schema.Parse(model) + if err != nil { + return err + } + + q := database.QuerierFromContext(ctx, dao.db) + return sch.Select(model, options, q) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// GetById is a generic function to get a model (row) based on the ID of the model +func (dao *DAO) GetById(ctx context.Context, model models.Modeler) error { + if model == nil { + return utils.ErrNilPtr + } + + if model.Id() == "" { + return utils.ErrInvalidId + } + + options := &database.Options{ + Where: squirrel.Eq{model.Table() + ".id": model.Id()}, + } + + return dao.Get(ctx, model, options) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// List is a generic function to list models (rows) +func (dao *DAO) List(ctx context.Context, model any, options *database.Options) error { + sch, err := schema.Parse(model) + if err != nil { + return err + } + + if options != nil && options.Pagination != nil { + count, err := dao.Count(ctx, model, options) + if err != nil { + return err + } + + options.Pagination.SetCount(count) + } + + q := database.QuerierFromContext(ctx, dao.db) + err = sch.Select(model, options, q) + if err != nil && err != sql.ErrNoRows { + return err + } + + return nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Update is a generic function to update a model in the database +func (dao *DAO) Update(ctx context.Context, model models.Modeler) (bool, error) { + sch, err := schema.Parse(model) + if err != nil { + return false, err + } + + if model.Id() == "" { + return false, utils.ErrInvalidId + } + + model.RefreshUpdatedAt() + + q := database.QuerierFromContext(ctx, dao.db) + res, err := sch.Update(model, &database.Options{Where: squirrel.Eq{model.Table() + ".id": model.Id()}}, q) + if err != nil { + return false, err + } + + rowCount, err := res.RowsAffected() + if err != nil { + return false, err + } + + return rowCount > 0, nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Delete is a generic function to delete a model (row) +// +// If options is nil or options.Where is nil, the function will delete the model based on the ID +// of the model +func (dao *DAO) Delete(ctx context.Context, model models.Modeler, options *database.Options) error { + sch, err := schema.Parse(model) + if err != nil { + return err + } + + if options == nil || options.Where == nil { + options = &database.Options{Where: squirrel.Eq{model.Table() + ".id": model.Id()}} + } + + q := database.QuerierFromContext(ctx, dao.db) + _, err = sch.Delete(options, q) + return err +} diff --git a/dao/common_test.go b/dao/common_test.go new file mode 100644 index 0000000..29b0a3f --- /dev/null +++ b/dao/common_test.go @@ -0,0 +1,617 @@ +package dao + +import ( + "context" + "database/sql" + "fmt" + "sync" + "testing" + "time" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/appFs" + "github.com/geerew/off-course/utils/logger" + "github.com/geerew/off-course/utils/pagination" + "github.com/geerew/off-course/utils/types" + "github.com/spf13/afero" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func setup(tb testing.TB) (*DAO, context.Context) { + tb.Helper() + + // Logger + var logs []*logger.Log + var logsMux sync.Mutex + logger, _, err := logger.InitLogger(&logger.BatchOptions{ + BatchSize: 1, + WriteFn: logger.TestWriteFn(&logs, &logsMux), + }) + require.NoError(tb, err, "Failed to initialize logger") + + // DB + dbManager, err := database.NewSqliteDBManager(&database.DatabaseConfig{ + IsDebug: false, + DataDir: "./oc_data", + AppFs: appFs.NewAppFs(afero.NewMemMapFs(), logger), + InMemory: true, + }) + + require.NoError(tb, err) + require.NotNil(tb, dbManager) + + dao := &DAO{db: dbManager.DataDb} + + return dao, context.Background() +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_Count(t *testing.T) { + t.Run("no entries", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{} + count, err := dao.Count(ctx, course, nil) + require.NoError(t, err) + require.Zero(t, count) + }) + + t.Run("entries", func(t *testing.T) { + dao, ctx := setup(t) + + for i := range 5 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.CreateCourse(ctx, course)) + } + + course := &models.Course{} + count, err := dao.Count(ctx, course, nil) + require.NoError(t, err) + require.Equal(t, count, 5) + }) + + t.Run("where", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + for i := range 3 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.Create(ctx, course)) + courses = append(courses, course) + } + + course := &models.Course{} + + // ---------------------------- + // EQUALS ID + // ---------------------------- + count, err := dao.Count(ctx, course, &database.Options{Where: squirrel.Eq{course.Table() + ".id": courses[1].ID}}) + require.NoError(t, err) + require.Equal(t, 1, count) + + // ---------------------------- + // NOT EQUALS ID + // ---------------------------- + count, err = dao.Count(ctx, course, &database.Options{Where: squirrel.NotEq{course.Table() + ".id": courses[1].ID}}) + require.NoError(t, err) + require.Equal(t, 2, count) + + // ---------------------------- + // ERROR + // ---------------------------- + count, err = dao.Count(ctx, course, &database.Options{Where: squirrel.Eq{"": ""}}) + require.ErrorContains(t, err, "syntax error") + require.Zero(t, count) + }) + + t.Run("invalid model", func(t *testing.T) { + dao, ctx := setup(t) + + count, err := dao.Count(ctx, nil, nil) + require.ErrorIs(t, err, utils.ErrNilPtr) + require.Zero(t, count) + }) + + t.Run("db error", func(t *testing.T) { + dao, ctx := setup(t) + course := &models.Course{} + + _, err := dao.db.Exec("DROP TABLE IF EXISTS " + course.Table()) + require.NoError(t, err) + + _, err = dao.Count(ctx, course, nil) + require.ErrorContains(t, err, "no such table: "+course.Table()) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_Get(t *testing.T) { + t.Run("found", func(t *testing.T) { + dao, ctx := setup(t) + + // Create course + course := &models.Course{Title: "Course 1", Path: "/course-1", Available: true, CardPath: "/course-1/card-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + // Get course + courseResult := &models.Course{} + require.NoError(t, dao.Get(ctx, courseResult, nil)) + require.Equal(t, course.ID, courseResult.ID) + require.True(t, courseResult.CreatedAt.Equal(course.CreatedAt)) + require.True(t, courseResult.UpdatedAt.Equal(course.UpdatedAt)) + require.Equal(t, course.Title, courseResult.Title) + require.Equal(t, course.Path, courseResult.Path) + require.Equal(t, course.CardPath, courseResult.CardPath) + require.True(t, courseResult.Available) + require.Empty(t, courseResult.ScanStatus) + require.NotEmpty(t, courseResult.Progress.ID) + require.Equal(t, course.ID, courseResult.Progress.CourseID) + require.False(t, courseResult.Progress.Started) + require.Zero(t, courseResult.Progress.Percent) + require.True(t, courseResult.Progress.StartedAt.IsZero()) + require.True(t, courseResult.Progress.CompletedAt.IsZero()) + + // Create scan + scan := &models.Scan{CourseID: course.ID} + require.NoError(t, dao.CreateScan(ctx, scan)) + + // Get scan + scanResult := &models.Scan{} + require.NoError(t, dao.Get(ctx, scanResult, nil)) + require.Equal(t, scan.ID, scanResult.ID) + require.True(t, scanResult.CreatedAt.Equal(scan.CreatedAt)) + require.True(t, scanResult.UpdatedAt.Equal(scan.UpdatedAt)) + require.Equal(t, scan.CourseID, scanResult.CourseID) + require.True(t, scanResult.Status.IsWaiting()) + require.Equal(t, course.Path, scanResult.CoursePath) + + // Get course (again) + courseResult = &models.Course{} + require.NoError(t, dao.Get(ctx, courseResult, nil)) + require.Equal(t, course.ID, courseResult.ID) + require.True(t, courseResult.ScanStatus.IsWaiting()) + + // Create asset + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset)) + + // Get asset + assetResult := &models.Asset{} + require.NoError(t, dao.Get(ctx, assetResult, nil)) + require.Equal(t, asset.ID, assetResult.ID) + require.True(t, assetResult.CreatedAt.Equal(asset.CreatedAt)) + require.True(t, assetResult.UpdatedAt.Equal(asset.UpdatedAt)) + require.Equal(t, asset.CourseID, assetResult.CourseID) + require.Equal(t, asset.Title, assetResult.Title) + require.Equal(t, asset.Prefix, assetResult.Prefix) + require.Equal(t, asset.Chapter, assetResult.Chapter) + require.Equal(t, asset.Type, assetResult.Type) + require.Equal(t, asset.Path, assetResult.Path) + require.Equal(t, asset.Hash, assetResult.Hash) + require.Len(t, assetResult.Attachments, 0) + + // Create attachment + attachment := &models.Attachment{ + AssetID: asset.ID, + CourseID: course.ID, + Title: "Attachment 1", + Path: "/course-1/01 Attachment 1.txt", + } + require.NoError(t, dao.CreateAttachment(ctx, attachment)) + + // Get attachment + attachmentResult := &models.Attachment{} + require.NoError(t, dao.Get(ctx, attachmentResult, nil)) + require.Equal(t, attachment.ID, attachmentResult.ID) + require.True(t, attachmentResult.CreatedAt.Equal(attachment.CreatedAt)) + require.True(t, attachmentResult.UpdatedAt.Equal(attachment.UpdatedAt)) + require.Equal(t, attachment.AssetID, attachmentResult.AssetID) + require.Equal(t, attachment.CourseID, attachmentResult.CourseID) + require.Equal(t, attachment.Title, attachmentResult.Title) + require.Equal(t, attachment.Path, attachmentResult.Path) + + // Get asset (again) + assetResult = &models.Asset{} + require.NoError(t, dao.Get(ctx, assetResult, nil)) + require.Equal(t, asset.ID, assetResult.ID) + require.Len(t, assetResult.Attachments, 1) + require.Equal(t, attachment.Title, assetResult.Attachments[0].Title) + + // Create tag + tag := &models.Tag{Tag: "Tag 1"} + require.NoError(t, dao.CreateTag(ctx, tag)) + + // Get tag + tagResult := &models.Tag{} + require.NoError(t, dao.Get(ctx, tagResult, nil)) + require.Equal(t, tag.ID, tagResult.ID) + require.True(t, tagResult.CreatedAt.Equal(tag.CreatedAt)) + require.True(t, tagResult.UpdatedAt.Equal(tag.UpdatedAt)) + require.Equal(t, tag.Tag, tagResult.Tag) + require.Len(t, tagResult.CourseTags, 0) + + // Create course tag + courseTag := &models.CourseTag{TagID: tag.ID, CourseID: course.ID} + require.NoError(t, dao.CreateCourseTag(ctx, courseTag)) + + // Get course tag + courseTagResult := &models.CourseTag{} + require.NoError(t, dao.Get(ctx, courseTagResult, nil)) + require.Equal(t, courseTag.ID, courseTagResult.ID) + require.True(t, courseTagResult.CreatedAt.Equal(courseTag.CreatedAt)) + require.True(t, courseTagResult.UpdatedAt.Equal(courseTag.UpdatedAt)) + require.Equal(t, courseTag.TagID, courseTagResult.TagID) + require.Equal(t, courseTag.CourseID, courseTagResult.CourseID) + require.Equal(t, course.Title, courseTagResult.Course) + require.Equal(t, tag.Tag, courseTagResult.Tag) + + // Get tag (again) + tagResult = &models.Tag{} + require.NoError(t, dao.Get(ctx, tagResult, nil)) + require.Equal(t, tag.ID, tagResult.ID) + require.Len(t, tagResult.CourseTags, 1) + + // Create user + user := &models.User{Username: "user1", PasswordHash: "1234", Role: types.UserRoleAdmin} + require.NoError(t, dao.CreateUser(ctx, user)) + + // Get user + userResult := &models.User{} + require.NoError(t, dao.Get(ctx, userResult, nil)) + require.Equal(t, user.ID, userResult.ID) + require.True(t, userResult.CreatedAt.Equal(user.CreatedAt)) + require.True(t, userResult.UpdatedAt.Equal(user.UpdatedAt)) + require.Equal(t, user.Username, userResult.Username) + require.Equal(t, user.PasswordHash, userResult.PasswordHash) + require.Equal(t, user.Role, userResult.Role) + }) + + t.Run("not found", func(t *testing.T) { + dao, ctx := setup(t) + course := &models.Course{} + err := dao.Get(ctx, course, &database.Options{Where: squirrel.Eq{course.Table() + ".path": "1234"}}) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("where", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + for i := range 3 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.Create(ctx, course)) + courses = append(courses, course) + time.Sleep(1 * time.Millisecond) + } + + courseResult := &models.Course{} + require.NoError(t, dao.Get(ctx, courseResult, &database.Options{Where: squirrel.Eq{models.COURSE_TABLE + ".path": courses[1].Path}})) + require.Equal(t, courses[1].ID, courseResult.ID) + }) + + t.Run("orderby", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + for i := range 3 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.Create(ctx, course)) + courses = append(courses, course) + time.Sleep(1 * time.Millisecond) + } + + result := &models.Course{} + options := &database.Options{OrderBy: []string{fmt.Sprintf("%s.title DESC", models.COURSE_TABLE)}} + require.NoError(t, dao.Get(ctx, result, options)) + require.Equal(t, courses[2].ID, result.ID) + }) + + t.Run("invalid model", func(t *testing.T) { + dao, ctx := setup(t) + err := dao.Get(ctx, nil, nil) + require.ErrorIs(t, err, utils.ErrNilPtr) + }) + + t.Run("invalid where", func(t *testing.T) { + dao, ctx := setup(t) + err := dao.Get(ctx, &models.Course{}, &database.Options{Where: squirrel.Eq{"`": "`"}}) + require.ErrorContains(t, err, "SQL logic error: unrecognized token") + }) + + t.Run("db error", func(t *testing.T) { + dao, ctx := setup(t) + course := &models.Course{} + + _, err := dao.db.Exec("DROP TABLE IF EXISTS " + course.Table()) + require.NoError(t, err) + + err = dao.Get(ctx, course, nil) + require.ErrorContains(t, err, "no such table: "+course.Table()) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_GetById(t *testing.T) { + t.Run("found", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1", Available: true, CardPath: "/course-1/card-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + courseResult := &models.Course{Base: models.Base{ID: course.ID}} + require.NoError(t, dao.GetById(ctx, courseResult)) + require.Equal(t, course.ID, courseResult.ID) + }) + + t.Run("not found", func(t *testing.T) { + dao, ctx := setup(t) + err := dao.GetById(ctx, &models.Course{Base: models.Base{ID: "1234"}}) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("invalid model", func(t *testing.T) { + dao, ctx := setup(t) + err := dao.GetById(ctx, nil) + require.ErrorIs(t, err, utils.ErrNilPtr) + }) + + t.Run("invalid id", func(t *testing.T) { + dao, ctx := setup(t) + err := dao.GetById(ctx, &models.Course{}) + require.ErrorIs(t, err, utils.ErrInvalidId) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_List(t *testing.T) { + t.Run("no entries", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + err := dao.List(ctx, &courses, nil) + require.NoError(t, err) + require.Empty(t, courses) + }) + + t.Run("entries", func(t *testing.T) { + dao, ctx := setup(t) + + for i := range 5 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.Create(ctx, course)) + } + + courses := []*models.Course{} + err := dao.List(ctx, &courses, nil) + require.NoError(t, err) + require.Len(t, courses, 5) + }) + + t.Run("pagination", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + for i := range 17 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.Create(ctx, course)) + courses = append(courses, course) + time.Sleep(1 * time.Millisecond) + } + + coursesResult := []*models.Course{} + + // Page 1 (10 items) + p := pagination.New(1, 10) + require.NoError(t, dao.List(ctx, &coursesResult, &database.Options{Pagination: p})) + require.Len(t, coursesResult, 10) + require.Equal(t, 17, p.TotalItems()) + require.Equal(t, courses[0].ID, coursesResult[0].ID) + require.Equal(t, courses[9].ID, coursesResult[9].ID) + + // Page 2 (7 items) + p = pagination.New(2, 10) + require.NoError(t, dao.List(ctx, &coursesResult, &database.Options{Pagination: p})) + require.Len(t, coursesResult, 7) + require.Equal(t, 17, p.TotalItems()) + require.Equal(t, courses[10].ID, coursesResult[0].ID) + require.Equal(t, courses[16].ID, coursesResult[6].ID) + }) + + t.Run("orderby", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + for i := range 3 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.Create(ctx, course)) + courses = append(courses, course) + time.Sleep(1 * time.Millisecond) + } + + // CREATED_AT DESC + coursesResult := []*models.Course{} + options := &database.Options{OrderBy: []string{models.COURSE_TABLE + ".title DESC"}} + require.NoError(t, dao.List(ctx, &coursesResult, options)) + require.Len(t, coursesResult, 3) + require.Equal(t, courses[2].ID, coursesResult[0].ID) + + // // // ---------------------------- + // // // SCAN_STATUS DESC + // // // ---------------------------- + + // // // Create a scan for course 2 and 3 + // // scanDao := NewScanDao(db) + + // // testData[1].Scan = &models.Scan{CourseID: testData[1].ID} + // // require.Nil(t, scanDao.Create(testData[1].Scan, nil)) + // // testData[2].Scan = &models.Scan{CourseID: testData[2].ID} + // // require.Nil(t, scanDao.Create(testData[2].Scan, nil)) + + // // // Set course 3 to processing + // // testData[2].Scan.Status = types.NewScanStatus(types.ScanStatusProcessing) + // // require.Nil(t, scanDao.Update(testData[2].Scan, nil)) + + // // result, err = dao.List(&database.DatabaseParams{OrderBy: []string{dao.Table() + ".scan_status desc"}}, nil) + // // require.Nil(t, err) + // // require.Len(t, result, 3) + + // // require.Equal(t, testData[0].ID, result[2].ID) + // // require.Equal(t, testData[1].ID, result[1].ID) + // // require.Equal(t, testData[2].ID, result[0].ID) + + // // // ---------------------------- + // // // SCAN_STATUS ASC + // // // ---------------------------- + // // result, err = dao.List(&database.DatabaseParams{OrderBy: []string{dao.Table() + ".scan_status asc"}}, nil) + // // require.Nil(t, err) + // // require.Len(t, result, 3) + + // // require.Equal(t, testData[0].ID, result[0].ID) + // // require.Equal(t, testData[1].ID, result[1].ID) + // // require.Equal(t, testData[2].ID, result[2].ID) + + }) + + t.Run("where", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + for i := range 3 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.Create(ctx, course)) + courses = append(courses, course) + time.Sleep(1 * time.Millisecond) + } + + // Equals ID or ID + coursesResult := []*models.Course{} + options := &database.Options{ + Where: squirrel.Or{ + squirrel.Eq{models.COURSE_TABLE + ".id": courses[1].ID}, + squirrel.Eq{models.COURSE_TABLE + ".id": courses[2].ID}, + }, + OrderBy: []string{models.COURSE_TABLE + ".created_at ASC"}, + } + require.NoError(t, dao.List(ctx, &coursesResult, options)) + require.Len(t, coursesResult, 2) + require.Equal(t, courses[1].ID, coursesResult[0].ID) + require.Equal(t, courses[2].ID, coursesResult[1].ID) + }) + + t.Run("invalid", func(t *testing.T) { + dao, ctx := setup(t) + + // Nil + require.ErrorIs(t, dao.List(ctx, nil, nil), utils.ErrNilPtr) + + // Not a pointer + require.ErrorIs(t, dao.List(ctx, []*models.Course{}, nil), utils.ErrNotPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_Delete(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.Create(ctx, course)) + + require.NoError(t, dao.Delete(ctx, course, &database.Options{Where: squirrel.Eq{course.Table() + ".path": course.Path}})) + }) + + t.Run("nil model", func(t *testing.T) { + dao, ctx := setup(t) + err := dao.Delete(ctx, nil, &database.Options{Where: squirrel.Eq{"path": "1234"}}) + require.ErrorIs(t, err, utils.ErrNilPtr) + }) + + t.Run("nil where", func(t *testing.T) { + dao, ctx := setup(t) + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.Create(ctx, course)) + require.NoError(t, dao.Delete(ctx, course, nil)) + + // Check if it was deleted + courseResult := &models.Course{Base: models.Base{ID: course.ID}} + err := dao.GetById(ctx, courseResult) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("db error", func(t *testing.T) { + dao, ctx := setup(t) + course := &models.Course{} + + _, err := dao.db.Exec("DROP TABLE IF EXISTS " + course.Table()) + require.NoError(t, err) + + err = dao.Delete(ctx, course, &database.Options{Where: squirrel.Eq{"path": "1234"}}) + require.ErrorContains(t, err, "no such table: "+course.Table()) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Benchmark_GetById(b *testing.B) { + dao, ctx := setup(b) + + for i := 0; i < 1000; i++ { + course := &models.Course{} + course.ID = fmt.Sprintf("%d", i) + course.Title = fmt.Sprintf("Course %d", i) + course.Path = fmt.Sprintf("/course-%d", i) + require.NoError(b, dao.CreateCourse(ctx, course)) + + courseProgress := &models.CourseProgress{} + require.NoError(b, dao.Get(ctx, courseProgress, &database.Options{Where: squirrel.Eq{courseProgress.Table() + ".course_id": course.ID}})) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + courseResult := &models.Course{Base: models.Base{ID: fmt.Sprintf("%d", (i % 1000))}} + require.NoError(b, dao.GetById(ctx, courseResult)) + } +} diff --git a/dao/course.go b/dao/course.go new file mode 100644 index 0000000..5e4c82f --- /dev/null +++ b/dao/course.go @@ -0,0 +1,176 @@ +package dao + +import ( + "context" + "slices" + "strings" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateCourse creates a course and course progress +func (dao *DAO) CreateCourse(ctx context.Context, course *models.Course) error { + if course == nil { + return utils.ErrNilPtr + } + + return dao.db.RunInTransaction(ctx, func(txCtx context.Context) error { + err := dao.Create(txCtx, course) + if err != nil { + return err + } + + courseProgress := &models.CourseProgress{CourseID: course.Id()} + return dao.CreateCourseProgress(txCtx, courseProgress) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UpdateCourse updates a course +func (dao *DAO) UpdateCourse(ctx context.Context, course *models.Course) error { + if course == nil { + return utils.ErrNilPtr + } + + _, err := dao.Update(ctx, course) + return err +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// ClassifyCoursePaths classifies the given paths into one of the following categories: +// - PathClassificationNone: The path does not exist in the courses table +// - PathClassificationAncestor: The path is an ancestor of a course path +// - PathClassificationCourse: The path is an exact match to a course path +// - PathClassificationDescendant: The path is a descendant of a course path +// +// The paths are returned as a map with the original path as the key and the classification as the +// value +func (dao *DAO) ClassifyCoursePaths(ctx context.Context, paths []string) (map[string]types.PathClassification, error) { + course := &models.Course{} + + paths = slices.DeleteFunc(paths, func(s string) bool { + return s == "" + }) + + if len(paths) == 0 { + return nil, nil + } + + results := make(map[string]types.PathClassification) + for _, path := range paths { + results[path] = types.PathClassificationNone + } + + whereClause := make([]squirrel.Sqlizer, len(paths)) + for i, path := range paths { + whereClause[i] = squirrel.Like{course.Table() + ".path": path + "%"} + } + + query, args, _ := squirrel. + StatementBuilder. + Select(course.Table() + ".path"). + From(course.Table()). + Where(squirrel.Or(whereClause)). + ToSql() + + q := database.QuerierFromContext(ctx, dao.db) + rows, err := q.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var coursePath string + coursePaths := []string{} + for rows.Next() { + if err := rows.Scan(&coursePath); err != nil { + return nil, err + } + coursePaths = append(coursePaths, coursePath) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + // Process + for _, path := range paths { + for _, coursePath := range coursePaths { + if coursePath == path { + results[path] = types.PathClassificationCourse + break + } else if strings.HasPrefix(coursePath, path) { + results[path] = types.PathClassificationAncestor + break + } else if strings.HasPrefix(path, coursePath) && results[path] != types.PathClassificationAncestor { + results[path] = types.PathClassificationDescendant + break + } + } + } + + return results, nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// // ProcessOrderBy takes an array of strings representing orderBy clauses and returns a processed +// // version of this array +// // +// // It will creates a new list of valid table columns based upon columns() for the current +// // DAO. Additionally, it handles the special case of 'scan_status' column, which requires custom +// // sorting logic, via a CASE statement. +// // +// // The custom sorting logic is defined as follows: +// // - NULL values are treated as the lowest value (sorted first in ASC, last in DESC) +// // - 'waiting' status is treated as the second value +// // - 'processing' status is treated as the third value +// func (dao *CourseDao) ProcessOrderBy(orderBy []string, validOrderByColumns []string) []string { +// if len(orderBy) == 0 { +// return orderBy +// } + +// var processedOrderBy []string + +// for _, ob := range orderBy { +// t, c := extractTableAndColumn(ob) + +// // Prefix the table with the dao's table if not found +// if t == "" { +// t = dao.Table() +// ob = t + "." + ob +// } + +// if isValidOrderBy(t, c, validOrderByColumns) { +// // When the column is 'scan_status', apply the custom sorting logic +// if c == "scan_status" { +// // Determine the sort direction, defaulting to ASC if not specified +// parts := strings.Fields(ob) +// sortDirection := "ASC" +// if len(parts) > 1 { +// sortDirection = strings.ToUpper(parts[1]) +// } + +// caseStmt := "CASE " + +// "WHEN scan_status IS NULL THEN 1 " + +// "WHEN scan_status = 'waiting' THEN 2 " + +// "WHEN scan_status = 'processing' THEN 3 " + +// "END " + sortDirection + +// processedOrderBy = append(processedOrderBy, caseStmt) +// } else { +// processedOrderBy = append(processedOrderBy, ob) +// } +// } +// } + +// return processedOrderBy +// } diff --git a/dao/course_progress.go b/dao/course_progress.go new file mode 100644 index 0000000..a195677 --- /dev/null +++ b/dao/course_progress.go @@ -0,0 +1,105 @@ +package dao + +import ( + "context" + "database/sql" + "math" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateCourseProgress creates a course progress +func (dao *DAO) CreateCourseProgress(ctx context.Context, courseProgress *models.CourseProgress) error { + if courseProgress == nil { + return utils.ErrNilPtr + } + + return dao.Create(ctx, courseProgress) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UpdateCourseProgress update a course progress +func (dao *DAO) UpdateCourseProgress(ctx context.Context, courseProgress *models.CourseProgress) error { + if courseProgress == nil { + return utils.ErrNilPtr + } + + _, err := dao.Update(ctx, courseProgress) + return err +} + +// Refresh refreshes the current course progress for the given ID +// +// It calculates the number of assets, number of completed assets and number of started video assets, +// then calculates the percent complete and whether the course has been started +// +// Based upon this calculation, +// - If the course has been started and `started_at` is null, `started_at` will be set to NOW +// - If the course is not started, `started_at` is set to null +// - If the course is complete and `completed_at` is null, `completed_at` is set to NOW +// - If the course is not complete, `completed_at` is set to null +func (dao *DAO) RefreshCourseProgress(ctx context.Context, courseID string) error { + if courseID == "" { + return utils.ErrInvalidId + } + + // Count the number of assets, number of completed assets and number of video assets started for + // this course + query, args, _ := squirrel. + StatementBuilder. + PlaceholderFormat(squirrel.Question). + Select( + "COUNT(DISTINCT "+models.ASSET_TABLE+".id) AS total_count", + "SUM(CASE WHEN "+models.ASSET_PROGRESS_TABLE+".completed THEN 1 ELSE 0 END) AS completed_count", + "SUM(CASE WHEN "+models.ASSET_PROGRESS_TABLE+".video_pos > 0 THEN 1 ELSE 0 END) AS started_count"). + From(models.ASSET_TABLE). + LeftJoin(models.ASSET_PROGRESS_TABLE + " ON " + models.ASSET_TABLE + ".id = " + models.ASSET_PROGRESS_TABLE + ".asset_id"). + Where(squirrel.And{squirrel.Eq{models.ASSET_TABLE + ".course_id": courseID}}). + ToSql() + + var totalAssetCount sql.NullInt32 + var completedAssetCount sql.NullInt32 + var startedAssetCount sql.NullInt32 + + q := database.QuerierFromContext(ctx, dao.db) + err := q.QueryRow(query, args...).Scan(&totalAssetCount, &completedAssetCount, &startedAssetCount) + if err != nil { + return err + } + + // Get the course progress + courseProgress := &models.CourseProgress{} + err = dao.Get(ctx, courseProgress, &database.Options{Where: squirrel.Eq{courseProgress.Table() + ".course_id": courseID}}) + if err != nil { + return err + } + + now := types.NowDateTime() + + courseProgress.Percent = int(math.Abs((float64(completedAssetCount.Int32) * float64(100)) / float64(totalAssetCount.Int32))) + + if startedAssetCount.Int32 > 0 || courseProgress.Percent > 0 && courseProgress.Percent <= 100 { + courseProgress.Started = true + courseProgress.StartedAt = now + } else { + courseProgress.Started = false + courseProgress.StartedAt = types.DateTime{} + } + + if courseProgress.Percent == 100 { + courseProgress.CompletedAt = now + } else { + courseProgress.CompletedAt = types.DateTime{} + } + + // Update the course progress + _, err = dao.Update(ctx, courseProgress) + return err +} diff --git a/dao/course_progress_test.go b/dao/course_progress_test.go new file mode 100644 index 0000000..30b29db --- /dev/null +++ b/dao/course_progress_test.go @@ -0,0 +1,133 @@ +package dao + +import ( + "database/sql" + "testing" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" + "github.com/stretchr/testify/require" +) + +func Test_CreateCourseProgress(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.Create(ctx, course)) + + courseProgress := &models.CourseProgress{CourseID: course.ID} + require.NoError(t, dao.CreateCourseProgress(ctx, courseProgress)) + }) + + t.Run("nil pointer", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.CreateCourseProgress(ctx, nil), utils.ErrNilPtr) + }) + + t.Run("invalid course id", func(t *testing.T) { + dao, ctx := setup(t) + courseProgress := &models.CourseProgress{CourseID: "invalid"} + require.ErrorContains(t, dao.CreateCourseProgress(ctx, courseProgress), "FOREIGN KEY constraint failed") + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_RefreshCourseProgress(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + require.False(t, course.Progress.Started) + require.True(t, course.Progress.StartedAt.IsZero()) + require.Zero(t, course.Progress.Percent) + require.True(t, course.Progress.CompletedAt.IsZero()) + + // Create asset + asset1 := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/01 asset.mp4", + Hash: "1234", + } + require.NoError(t, dao.CreateAsset(ctx, asset1)) + + // Set asset1 progress (video_pos > 0) + assetProgress := &models.AssetProgress{AssetID: asset1.ID, CourseID: course.ID, VideoPos: 1} + require.NoError(t, dao.CreateOrUpdateAssetProgress(ctx, assetProgress)) + + require.NoError(t, dao.GetById(ctx, course)) + require.True(t, course.Progress.Started) + require.False(t, course.Progress.StartedAt.IsZero()) + require.Zero(t, 0, course.Progress.Percent) + require.True(t, course.Progress.CompletedAt.IsZero()) + + // Set asset progress (video_pos = 0) + assetProgress.VideoPos = 0 + require.NoError(t, dao.CreateOrUpdateAssetProgress(ctx, assetProgress)) + + require.NoError(t, dao.GetById(ctx, course)) + require.False(t, course.Progress.Started) + require.True(t, course.Progress.StartedAt.IsZero()) + require.Zero(t, 0, course.Progress.Percent) + require.True(t, course.Progress.CompletedAt.IsZero()) + + // Set asset progress (completed = true) + assetProgress.Completed = true + require.NoError(t, dao.CreateOrUpdateAssetProgress(ctx, assetProgress)) + + require.NoError(t, dao.GetById(ctx, course)) + require.True(t, course.Progress.Started) + require.False(t, course.Progress.StartedAt.IsZero()) + require.Equal(t, 100, course.Progress.Percent) + require.False(t, course.Progress.CompletedAt.IsZero()) + + // Add another asset + asset2 := &models.Asset{ + CourseID: course.ID, + Title: "Asset 2", + Prefix: sql.NullInt16{Int16: 2, Valid: true}, + Chapter: "Chapter 2", + Type: *types.NewAsset("mp4"), + Path: "/course-1/02 asset.mp4", + Hash: "5678", + } + require.NoError(t, dao.CreateAsset(ctx, asset2)) + + // Check course progress + require.NoError(t, dao.GetById(ctx, course)) + require.True(t, course.Progress.Started) + require.False(t, course.Progress.StartedAt.IsZero()) + require.Equal(t, 50, course.Progress.Percent) + require.True(t, course.Progress.CompletedAt.IsZero()) + }) + + t.Run("invalid course id", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.RefreshCourseProgress(ctx, ""), utils.ErrInvalidId) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CourseProgressDeleteCascade(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course", Path: "/course"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + courseProgress := &models.CourseProgress{CourseID: course.ID} + require.NoError(t, dao.Get(ctx, courseProgress, &database.Options{Where: squirrel.Eq{courseProgress.Table() + ".course_id": course.ID}})) + + require.NoError(t, dao.Delete(ctx, courseProgress, nil)) + + err := dao.Get(ctx, courseProgress, &database.Options{Where: squirrel.Eq{courseProgress.Table() + ".course_id": course.ID}}) + require.ErrorIs(t, err, sql.ErrNoRows) +} diff --git a/dao/course_tag.go b/dao/course_tag.go new file mode 100644 index 0000000..ac7888b --- /dev/null +++ b/dao/course_tag.go @@ -0,0 +1,96 @@ +package dao + +import ( + "context" + "database/sql" + "fmt" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateCourseTag creates a course tag +func (dao *DAO) CreateCourseTag(ctx context.Context, courseTag *models.CourseTag) error { + if courseTag == nil { + return utils.ErrNilPtr + } + + if courseTag.TagID == "" && courseTag.Tag == "" { + return fmt.Errorf("tag ID and tag cannot be empty") + } + + return dao.db.RunInTransaction(ctx, func(txCtx context.Context) error { + if courseTag.TagID != "" { + return dao.Create(txCtx, courseTag) + } + + // Get the tag by tag name + tag := models.Tag{} + err := dao.Get(txCtx, &tag, &database.Options{Where: squirrel.Eq{models.TAG_TABLE + ".tag": courseTag.Tag}}) + if err != nil && err != sql.ErrNoRows { + return err + } + + // If the tag does not exist, create it + if err == sql.ErrNoRows { + tag.Tag = courseTag.Tag + err = dao.Create(txCtx, &tag) + if err != nil { + return err + } + } + + courseTag.TagID = tag.ID + + return dao.Create(txCtx, courseTag) + + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// // ListCourseIdsByTags lists course IDs containing all tags in the slice +// func (dao *DAO) ListCourseIdsByTags(tags []string, dbParams *database.DatabaseParams, tx *database.Tx) ([]string, error) { +// if len(tags) == 0 { +// return nil, nil +// } + +// if dbParams == nil { +// dbParams = &database.DatabaseParams{} +// } + +// selectColumns, _ := tableColumnsOrPanic(models.CourseTag{}, dao.Table()) + +// dbParams.OrderBy = genericProcessOrderBy(dbParams.OrderBy, selectColumns, dao, false) +// dbParams.Columns = []string{dao.Table() + ".course_id"} +// dbParams.Where = squirrel.Eq{NewTagDao(dao.db).Table() + ".tag": tags} +// dbParams.GroupBys = []string{dao.Table() + ".course_id"} +// dbParams.Having = squirrel.Expr("COUNT(DISTINCT "+NewTagDao(dao.db).Table()+".tag) = ?", len(tags)) +// dbParams.Pagination = nil + +// rows, err := genericListWithoutScan(dao, dbParams, tx) +// if err != nil { +// return nil, err +// } +// defer rows.Close() + +// var courseIds []string +// for rows.Next() { +// var courseId string +// if err := rows.Scan(&courseId); err != nil { +// return nil, err +// } + +// courseIds = append(courseIds, courseId) +// } + +// if err := rows.Err(); err != nil { +// return nil, err +// } + +// return courseIds, nil +// } diff --git a/dao/course_tag_test.go b/dao/course_tag_test.go new file mode 100644 index 0000000..2973175 --- /dev/null +++ b/dao/course_tag_test.go @@ -0,0 +1,293 @@ +package dao + +import ( + "database/sql" + "fmt" + "testing" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/stretchr/testify/require" +) + +func Test_CreateCourseTag(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + for i := range 2 { + course := &models.Course{Title: fmt.Sprintf("Course %d", i), Path: fmt.Sprintf("/course-%d", i)} + require.NoError(t, dao.CreateCourse(ctx, course)) + courses = append(courses, course) + } + + tag := &models.Tag{Tag: "Go"} + require.NoError(t, dao.CreateTag(ctx, tag)) + + // Using ID + courseTagByID := &models.CourseTag{TagID: tag.ID, CourseID: courses[0].ID} + require.Nil(t, dao.CreateCourseTag(ctx, courseTagByID)) + + // Using Tag + courseTagByTag := &models.CourseTag{CourseID: courses[1].ID, Tag: "Go"} + require.Nil(t, dao.CreateCourseTag(ctx, courseTagByTag)) + + // Create (tag does not exist) + courseTagCreated := &models.CourseTag{CourseID: courses[1].ID, Tag: "TypeScript"} + require.Nil(t, dao.CreateCourseTag(ctx, courseTagCreated)) + }) + + t.Run("nil pointer", func(t *testing.T) { + dao, ctx := setup(t) + + require.ErrorIs(t, dao.CreateCourseTag(ctx, nil), utils.ErrNilPtr) + }) + + t.Run("invalid tag ID", func(t *testing.T) { + dao, ctx := setup(t) + + courseTag := &models.CourseTag{TagID: "invalid", CourseID: "invalid"} + require.ErrorContains(t, dao.CreateCourseTag(ctx, courseTag), "FOREIGN KEY constraint failed") + }) + + t.Run("invalid course ID", func(t *testing.T) { + dao, ctx := setup(t) + + tag := &models.Tag{Tag: "Go"} + require.NoError(t, dao.CreateTag(ctx, tag)) + + courseTag := &models.CourseTag{TagID: tag.ID, CourseID: "invalid"} + require.ErrorContains(t, dao.CreateCourseTag(ctx, courseTag), "FOREIGN KEY constraint failed") + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CourseTagDeleteCascade(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course", Path: "/course"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + courseTag := &models.CourseTag{CourseID: course.ID, Tag: "Tag 1"} + require.NoError(t, dao.CreateCourseTag(ctx, courseTag)) + + require.Nil(t, dao.Delete(ctx, course, nil)) + + err := dao.GetById(ctx, courseTag) + require.ErrorIs(t, err, sql.ErrNoRows) +} + +// // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// func TestCourseTag_ListCourseIdsByTags(t *testing.T) { +// t.Run("no entries", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// tags, err := dao.ListCourseIdsByTags([]string{"1234"}, nil, nil) +// require.Nil(t, err) +// require.Zero(t, tags) +// }) + +// t.Run("found", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// course1 := NewTestBuilder(t).Db(dao.db).Courses([]string{"course 1"}).Tags([]string{"Go", "Data Structures"}).Build()[0] +// course2 := NewTestBuilder(t).Db(dao.db).Courses([]string{"course 2"}).Tags([]string{"Data Structures", "TypeScript", "PHP"}).Build()[0] +// course3 := NewTestBuilder(t).Db(dao.db).Courses([]string{"course 3"}).Tags([]string{"Go", "Data Structures", "PHP"}).Build()[0] + +// // Order by title (asc) +// dbParams := &database.DatabaseParams{OrderBy: []string{NewCourseDao(dao.db).Table() + ".title asc"}} + +// // Go +// result, err := dao.ListCourseIdsByTags([]string{"Go"}, dbParams, nil) +// require.Nil(t, err) +// require.Len(t, result, 2) +// require.Equal(t, course1.ID, result[0]) +// require.Equal(t, course3.ID, result[1]) + +// // Go, Data Structures +// result, err = dao.ListCourseIdsByTags([]string{"Go", "Data Structures"}, dbParams, nil) +// require.Nil(t, err) +// require.Len(t, result, 2) +// require.Equal(t, course1.ID, result[0]) +// require.Equal(t, course3.ID, result[1]) + +// // Go, Data Structures, PHP +// result, err = dao.ListCourseIdsByTags([]string{"Go", "Data Structures", "PHP"}, dbParams, nil) +// require.Nil(t, err) +// require.Len(t, result, 1) +// require.Equal(t, course3.ID, result[0]) + +// // Go, Data Structures, PHP, TypeScript +// result, err = dao.ListCourseIdsByTags([]string{"Go", "Data Structures", "PHP", "TypeScript"}, dbParams, nil) +// require.Nil(t, err) +// require.Len(t, result, 0) + +// // Data Structures +// result, err = dao.ListCourseIdsByTags([]string{"Data Structures"}, dbParams, nil) +// require.Nil(t, err) +// require.Len(t, result, 3) +// require.Equal(t, course1.ID, result[0]) +// require.Equal(t, course2.ID, result[1]) +// require.Equal(t, course3.ID, result[2]) + +// }) + +// t.Run("db error", func(t *testing.T) { +// dao, db := courseTagSetup(t) + +// _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) +// require.Nil(t, err) + +// _, err = dao.ListCourseIdsByTags([]string{"1234"}, nil, nil) +// require.ErrorContains(t, err, "no such table: "+dao.Table()) +// }) +// } + +// // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// func TestCourseTag_List(t *testing.T) { +// t.Run("no entries", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// tags, err := dao.List(nil, nil) +// require.Nil(t, err) +// require.Zero(t, tags) +// }) + +// t.Run("found", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// NewTestBuilder(t).Db(dao.db).Courses(2).Tags(5).Build() + +// result, err := dao.List(nil, nil) +// require.Nil(t, err) +// require.Len(t, result, 10) +// }) + +// t.Run("orderby", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// NewTestBuilder(t).Db(dao.db).Courses(2).Tags([]string{"PHP", "Go", "Java", "TypeScript", "JavaScript"}).Build() +// tagDao := NewTagDao(dao.db) + +// // ---------------------------- +// // TAG DESC +// // ---------------------------- +// result, err := dao.List(&database.DatabaseParams{OrderBy: []string{tagDao.Table() + ".tag desc"}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 10) +// require.Equal(t, "TypeScript", result[0].Tag) + +// // ---------------------------- +// // TAG ASC +// // ---------------------------- +// result, err = dao.List(&database.DatabaseParams{OrderBy: []string{tagDao.Table() + ".tag asc"}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 10) +// require.Equal(t, "Go", result[0].Tag) +// }) + +// t.Run("where", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// testData := NewTestBuilder(t).Db(dao.db).Courses(2).Tags([]string{"PHP", "Go", "Java", "TypeScript", "JavaScript"}).Build() + +// courseDao := NewCourseDao(dao.db) +// tagDao := NewTagDao(dao.db) + +// // ---------------------------- +// // EQUALS (course title) +// // ---------------------------- +// result, err := dao.List(&database.DatabaseParams{Where: squirrel.Eq{courseDao.Table() + ".title": testData[0].Course.Title}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 5) + +// // ---------------------------- +// // Like (Java%) +// // ---------------------------- +// result, err = dao.List(&database.DatabaseParams{Where: squirrel.Like{tagDao.Table() + ".tag": "Java%"}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 4) + +// // ---------------------------- +// // ERROR +// // ---------------------------- +// result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) +// require.ErrorContains(t, err, "syntax error") +// require.Nil(t, result) +// }) + +// t.Run("pagination", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// NewTestBuilder(t).Db(dao.db).Courses(1).Tags(20).Build() + +// // ---------------------------- +// // Page 1 with 10 items +// // ---------------------------- +// p := pagination.New(1, 10) + +// result, err := dao.List(&database.DatabaseParams{Pagination: p, OrderBy: []string{"tags.tag asc"}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 10) +// require.Equal(t, 20, p.TotalItems()) +// require.Equal(t, "C", result[0].Tag) + +// // ---------------------------- +// // Page 2 with 10 items +// // ---------------------------- +// p = pagination.New(2, 10) + +// result, err = dao.List(&database.DatabaseParams{Pagination: p, OrderBy: []string{"tags.tag asc"}}, nil) +// require.Nil(t, err) +// require.Len(t, result, 10) +// require.Equal(t, 20, p.TotalItems()) +// require.Equal(t, "Perl", result[0].Tag) +// }) + +// t.Run("db error", func(t *testing.T) { +// dao, db := courseTagSetup(t) + +// _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) +// require.Nil(t, err) + +// _, err = dao.List(nil, nil) +// require.ErrorContains(t, err, "no such table: "+dao.Table()) +// }) +// } + +// // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// func TestCourseTag_Delete(t *testing.T) { +// t.Run("success", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// testData := NewTestBuilder(t).Db(dao.db).Courses(1).Tags([]string{"C", "Go", "JavaScript", "Perl"}).Build() + +// err := dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].Tags[1].ID}}, nil) +// require.Nil(t, err) + +// tags, err := dao.List(nil, nil) +// require.Nil(t, err) +// require.Len(t, tags, 3) +// }) + +// t.Run("no db params", func(t *testing.T) { +// dao, _ := courseTagSetup(t) + +// err := dao.Delete(nil, nil) +// require.ErrorIs(t, err, ErrMissingWhere) +// }) + +// t.Run("db error", func(t *testing.T) { +// dao, db := courseTagSetup(t) + +// _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) +// require.Nil(t, err) + +// err = dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"tag": "1234"}}, nil) +// require.ErrorContains(t, err, "no such table: "+dao.Table()) +// }) +// } diff --git a/dao/course_test.go b/dao/course_test.go new file mode 100644 index 0000000..c2de5e1 --- /dev/null +++ b/dao/course_test.go @@ -0,0 +1,146 @@ +package dao + +import ( + "fmt" + "testing" + "time" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CreateCourse(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + }) + + t.Run("nil", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.CreateCourse(ctx, nil), utils.ErrNilPtr) + }) + + t.Run("duplicate", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Base: models.Base{ID: "1"}, Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + // Duplicate ID + course = &models.Course{Base: models.Base{ID: "1"}, Title: "Course 2", Path: "/course-2"} + require.ErrorContains(t, dao.CreateCourse(ctx, course), "UNIQUE constraint failed: "+models.COURSE_TABLE+".id") + + // Duplicate Path + course = &models.Course{Base: models.Base{ID: "2"}, Title: "Course 2", Path: "/course-1"} + require.ErrorContains(t, dao.CreateCourse(ctx, course), "UNIQUE constraint failed: "+models.COURSE_TABLE+".path") + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_UpdateCourse(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + originalCourse := &models.Course{Title: "Course 1", Path: "/course-1", Available: true, CardPath: "/course-1/card-1"} + require.NoError(t, dao.CreateCourse(ctx, originalCourse)) + + time.Sleep(1 * time.Millisecond) + + newCourse := &models.Course{ + Base: originalCourse.Base, + Title: "Course 2", // Immutable + Path: "/course-2", // Immutable + Available: false, // Mutable + CardPath: "/course-2/card-2", // Mutable + } + require.NoError(t, dao.UpdateCourse(ctx, newCourse)) + + courseResult := &models.Course{Base: models.Base{ID: originalCourse.ID}} + require.NoError(t, dao.GetById(ctx, courseResult)) + require.Equal(t, originalCourse.ID, courseResult.ID) // No change + require.Equal(t, originalCourse.Title, courseResult.Title) // No change + require.Equal(t, originalCourse.Path, courseResult.Path) // No change + require.True(t, courseResult.CreatedAt.Equal(originalCourse.CreatedAt)) // No change + require.False(t, courseResult.Available) // Changed + require.Equal(t, newCourse.CardPath, courseResult.CardPath) // Changed + require.False(t, courseResult.UpdatedAt.Equal(originalCourse.UpdatedAt)) // Changed + }) + + t.Run("invalid", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + // Empty ID + course.ID = "" + require.ErrorIs(t, dao.UpdateCourse(ctx, course), utils.ErrInvalidId) + + // Nil Model + require.ErrorIs(t, dao.UpdateCourse(ctx, nil), utils.ErrNilPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_ClassifyCoursePaths(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + courses := []*models.Course{} + for i := range 3 { + c := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.CreateCourse(ctx, c)) + courses = append(courses, c) + } + + path1 := "/" // ancestor + path2 := "/test" // none + path3 := courses[2].Path // course + path4 := courses[2].Path + "/test" // descendant + + result, err := dao.ClassifyCoursePaths(ctx, []string{path1, path2, path3, path4}) + require.Nil(t, err) + + require.Equal(t, types.PathClassificationAncestor, result[path1]) + require.Equal(t, types.PathClassificationNone, result[path2]) + require.Equal(t, types.PathClassificationCourse, result[path3]) + require.Equal(t, types.PathClassificationDescendant, result[path4]) + }) + + t.Run("no paths", func(t *testing.T) { + dao, ctx := setup(t) + + result, err := dao.ClassifyCoursePaths(ctx, []string{}) + require.Nil(t, err) + require.Empty(t, result) + }) + + t.Run("empty path", func(t *testing.T) { + dao, ctx := setup(t) + + result, err := dao.ClassifyCoursePaths(ctx, []string{"", "", ""}) + require.Nil(t, err) + require.Empty(t, result) + }) + + t.Run("db error", func(t *testing.T) { + dao, ctx := setup(t) + + _, err := dao.db.Exec("DROP TABLE IF EXISTS " + (&models.Course{}).Table()) + require.Nil(t, err) + + result, err := dao.ClassifyCoursePaths(ctx, []string{"/"}) + require.ErrorContains(t, err, "no such table: "+(&models.Course{}).Table()) + require.Empty(t, result) + }) +} diff --git a/dao/log.go b/dao/log.go new file mode 100644 index 0000000..3b6ccdb --- /dev/null +++ b/dao/log.go @@ -0,0 +1,39 @@ +package dao + +import ( + "context" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// WriteLog writes a new log +func (dao *DAO) WriteLog(ctx context.Context, log *models.Log) error { + if log == nil { + return utils.ErrNilPtr + } + + return dao.Create(ctx, log) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// List lists logs +// func (dao *LogDao) List(dbParams *database.DatabaseParams, tx *database.Tx) ([]*models.Log, error) { +// if dbParams == nil { +// dbParams = &database.DatabaseParams{} +// } + +// // Always override the order by to created_at +// dbParams.OrderBy = []string{dao.Table() + ".created_at DESC"} + +// // Default the columns if not specified +// if len(dbParams.Columns) == 0 { +// selectColumns, _ := tableColumnsOrPanic(models.Log{}, dao.Table()) +// dbParams.Columns = selectColumns +// } + +// return genericList(dao, dbParams, dao.scanRow, tx) +// } diff --git a/dao/log_test.go b/dao/log_test.go new file mode 100644 index 0000000..6aaa743 --- /dev/null +++ b/dao/log_test.go @@ -0,0 +1,164 @@ +package dao + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/appFs" + "github.com/geerew/off-course/utils/logger" + "github.com/spf13/afero" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func setupLog(tb testing.TB) (*DAO, context.Context) { + tb.Helper() + + // Logger + var logs []*logger.Log + var logsMux sync.Mutex + logger, _, err := logger.InitLogger(&logger.BatchOptions{ + BatchSize: 1, + WriteFn: logger.TestWriteFn(&logs, &logsMux), + }) + require.NoError(tb, err, "Failed to initialize logger") + + // Filesystem + appFs := appFs.NewAppFs(afero.NewMemMapFs(), logger) + + // DB + dbManager, err := database.NewSqliteDBManager(&database.DatabaseConfig{ + IsDebug: false, + DataDir: "./oc_data", + AppFs: appFs, + InMemory: true, + }) + + require.NoError(tb, err) + require.NotNil(tb, dbManager) + + dao := &DAO{db: dbManager.LogsDb} + + return dao, context.Background() +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_WriteLog(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setupLog(t) + log := &models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", 1)} + require.NoError(t, dao.WriteLog(ctx, log)) + }) + + t.Run("nil", func(t *testing.T) { + dao, ctx := setupLog(t) + require.ErrorIs(t, dao.WriteLog(ctx, nil), utils.ErrNilPtr) + }) + +} + +// // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// func TestLog_List(t *testing.T) { +// t.Run("no entries", func(t *testing.T) { +// dao, _ := logSetup(t) + +// courses, err := dao.List(nil, nil) +// require.Nil(t, err) +// require.Zero(t, courses) +// }) + +// t.Run("found", func(t *testing.T) { +// dao, _ := logSetup(t) + +// for i := range 5 { +// require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) +// time.Sleep(1 * time.Millisecond) +// } + +// result, err := dao.List(nil, nil) +// require.Nil(t, err) +// require.Len(t, result, 5) +// require.Equal(t, "log 5", result[0].Message) +// require.Equal(t, "log 1", result[4].Message) +// }) + +// t.Run("where", func(t *testing.T) { +// dao, _ := logSetup(t) + +// for i := range 5 { +// require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) +// time.Sleep(1 * time.Millisecond) +// } + +// // ---------------------------- +// // EQUALS log 2 or log 3 +// // ---------------------------- +// result, err := dao.List( +// &database.DatabaseParams{Where: squirrel.Or{ +// squirrel.Eq{dao.Table() + ".message": "log 2"}, +// squirrel.Eq{dao.Table() + ".message": "log 3"}}}, +// nil) +// require.Nil(t, err) +// require.Len(t, result, 2) +// require.Equal(t, "log 3", result[0].Message) +// require.Equal(t, "log 2", result[1].Message) + +// // ---------------------------- +// // ERROR +// // ---------------------------- +// result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) +// require.ErrorContains(t, err, "syntax error") +// require.Nil(t, result) +// }) + +// t.Run("pagination", func(t *testing.T) { +// dao, _ := logSetup(t) + +// for i := range 17 { +// require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) +// time.Sleep(1 * time.Millisecond) +// } + +// // ---------------------------- +// // Page 1 with 10 items +// // ---------------------------- +// p := pagination.New(1, 10) + +// result, err := dao.List(&database.DatabaseParams{Pagination: p}, nil) +// require.Nil(t, err) +// require.Len(t, result, 10) +// require.Equal(t, 17, p.TotalItems()) +// require.Equal(t, "log 17", result[0].Message) +// require.Equal(t, "log 8", result[9].Message) + +// // ---------------------------- +// // Page 2 with 7 items +// // ---------------------------- +// p = pagination.New(2, 10) + +// result, err = dao.List(&database.DatabaseParams{Pagination: p}, nil) +// require.Nil(t, err) +// require.Len(t, result, 7) +// require.Equal(t, 17, p.TotalItems()) +// require.Equal(t, "log 7", result[0].Message) +// require.Equal(t, "log 1", result[6].Message) +// }) + +// t.Run("db error", func(t *testing.T) { +// dao, db := logSetup(t) + +// _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) +// require.Nil(t, err) + +// _, err = dao.List(nil, nil) +// require.ErrorContains(t, err, "no such table: "+dao.Table()) +// }) +// } diff --git a/dao/param.go b/dao/param.go new file mode 100644 index 0000000..06fa65b --- /dev/null +++ b/dao/param.go @@ -0,0 +1,52 @@ +package dao + +import ( + "context" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateParam creates a parameter +func (dao *DAO) CreateParam(ctx context.Context, param *models.Param) error { + if param == nil { + return utils.ErrNilPtr + } + + return dao.Create(ctx, param) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// GetParam gets a parameter by its key +func (dao *DAO) GetParamByKey(ctx context.Context, param *models.Param) error { + if param == nil { + return utils.ErrNilPtr + } + + if param.Key == "" { + return utils.ErrInvalidKey + } + + options := &database.Options{ + Where: squirrel.Eq{param.Table() + ".key": param.Key}, + } + + return dao.Get(ctx, param, options) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UpdateParam updates a parameter +func (dao *DAO) UpdateParam(ctx context.Context, param *models.Param) error { + if param == nil { + return utils.ErrNilPtr + } + + _, err := dao.Update(ctx, param) + return err +} diff --git a/dao/param_test.go b/dao/param_test.go new file mode 100644 index 0000000..86f0831 --- /dev/null +++ b/dao/param_test.go @@ -0,0 +1,195 @@ +package dao + +import ( + "testing" + "time" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CreateParam(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + param := &models.Param{Key: "param 1", Value: "value 1"} + require.NoError(t, dao.CreateParam(ctx, param)) + }) + + t.Run("nil", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.CreateParam(ctx, nil), utils.ErrNilPtr) + }) + + t.Run("duplicate", func(t *testing.T) { + dao, ctx := setup(t) + + param := &models.Param{Key: "param 1", Value: "value 1"} + require.NoError(t, dao.CreateParam(ctx, param)) + + require.ErrorContains(t, dao.CreateParam(ctx, param), "UNIQUE constraint failed: "+models.PARAM_TABLE+".key") + }) +} + +func Test_GetParamByKey(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + param := &models.Param{Key: "param 1", Value: "value 1"} + require.NoError(t, dao.CreateParam(ctx, param)) + + paramResult := &models.Param{Key: param.Key} + require.NoError(t, dao.GetParamByKey(ctx, paramResult)) + require.Equal(t, param.ID, paramResult.ID) + }) + + t.Run("invalid", func(t *testing.T) { + dao, ctx := setup(t) + + // Nil model + require.ErrorIs(t, dao.GetParamByKey(ctx, nil), utils.ErrNilPtr) + + // Invalid key + param := &models.Param{Base: models.Base{ID: "1234"}} + require.ErrorIs(t, dao.GetParamByKey(ctx, param), utils.ErrInvalidKey) + }) + + t.Run("duplicate", func(t *testing.T) { + dao, ctx := setup(t) + + param := &models.Param{Key: "param 1", Value: "value 1"} + require.NoError(t, dao.CreateParam(ctx, param)) + + require.ErrorContains(t, dao.CreateParam(ctx, param), "UNIQUE constraint failed: "+models.PARAM_TABLE+".key") + }) +} + +func Test_UpdateParam(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + originalParam := &models.Param{Key: "param 1", Value: "value 1"} + require.NoError(t, dao.CreateParam(ctx, originalParam)) + + time.Sleep(1 * time.Millisecond) + + newParam := &models.Param{ + Base: originalParam.Base, + Key: "new key", // Immutable + Value: "new value", // Mutable + } + require.NoError(t, dao.UpdateParam(ctx, newParam)) + + paramResult := &models.Param{Base: models.Base{ID: originalParam.ID}} + require.NoError(t, dao.GetById(ctx, paramResult)) + require.Equal(t, originalParam.ID, paramResult.ID) // No change + require.Equal(t, originalParam.Key, paramResult.Key) // No change + require.True(t, paramResult.CreatedAt.Equal(originalParam.CreatedAt)) // No change + require.Equal(t, newParam.Value, paramResult.Value) // Changed + require.False(t, paramResult.UpdatedAt.Equal(originalParam.UpdatedAt)) // Changed + }) + + t.Run("invalid", func(t *testing.T) { + dao, ctx := setup(t) + + param := &models.Param{Key: "param 1", Value: "value 1"} + require.NoError(t, dao.CreateParam(ctx, param)) + + // Empty ID + param.ID = "" + require.ErrorIs(t, dao.UpdateParam(ctx, param), utils.ErrInvalidId) + + // Nil Model + require.ErrorIs(t, dao.UpdateParam(ctx, nil), utils.ErrNilPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// func TestParam_Get(t *testing.T) { +// t.Run("found", func(t *testing.T) { +// dao, _ := paramSetup(t) + +// p, err := dao.Get("hasAdmin", nil) +// require.Nil(t, err) +// require.False(t, cast.ToBool(p.Value)) +// require.False(t, p.CreatedAt.IsZero()) +// require.False(t, p.UpdatedAt.IsZero()) +// }) + +// t.Run("not found", func(t *testing.T) { +// dao, _ := paramSetup(t) + +// p, err := dao.Get("test", nil) +// require.ErrorIs(t, err, sql.ErrNoRows) +// require.Nil(t, p) +// }) + +// t.Run("empty id", func(t *testing.T) { +// dao, _ := paramSetup(t) + +// p, err := dao.Get("", nil) +// require.ErrorIs(t, err, sql.ErrNoRows) +// require.Nil(t, p) +// }) + +// t.Run("db error", func(t *testing.T) { +// dao, db := paramSetup(t) + +// _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) +// require.Nil(t, err) + +// _, err = dao.Get("1234", nil) +// require.ErrorContains(t, err, "no such table: "+dao.Table()) +// }) +// } + +// // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// func TestParam_Update(t *testing.T) { +// t.Run("hasAdmin", func(t *testing.T) { +// dao, _ := paramSetup(t) + +// p, err := dao.Get("hasAdmin", nil) +// require.Nil(t, err) +// require.False(t, cast.ToBool(p.Value)) +// require.False(t, p.CreatedAt.IsZero()) +// require.False(t, p.UpdatedAt.IsZero()) + +// // Set to true +// p.Value = "true" +// require.Nil(t, dao.Update(p, nil)) + +// p2, err := dao.Get("hasAdmin", nil) +// require.Nil(t, err) +// require.True(t, cast.ToBool(p2.Value)) +// require.Equal(t, p.CreatedAt, p2.CreatedAt) +// require.NotEqual(t, p.UpdatedAt, p2.UpdatedAt) +// }) + +// t.Run("empty id", func(t *testing.T) { +// dao, _ := paramSetup(t) + +// err := dao.Update(&models.Param{}, nil) +// require.ErrorIs(t, err, ErrEmptyId) +// }) + +// t.Run("invalid id", func(t *testing.T) { +// dao, _ := paramSetup(t) + +// p := &models.Param{} +// p.ID = "1234" +// require.Nil(t, dao.Update(p, nil)) +// }) + +// t.Run("db error", func(t *testing.T) { +// dao, db := paramSetup(t) + +// _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) +// require.Nil(t, err) + +// _, err = dao.Get("1234", nil) +// require.ErrorContains(t, err, "no such table: "+dao.Table()) +// }) +// } diff --git a/dao/scan.go b/dao/scan.go new file mode 100644 index 0000000..1a9e2ff --- /dev/null +++ b/dao/scan.go @@ -0,0 +1,55 @@ +package dao + +import ( + "context" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateScan creates a scan +func (dao *DAO) CreateScan(ctx context.Context, scan *models.Scan) error { + if scan == nil { + return utils.ErrNilPtr + } + + // A scan should always be in the waiting state when created + if !scan.Status.IsWaiting() { + scan.Status.SetWaiting() + } + + return dao.Create(ctx, scan) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UpdateScan updates a scan +func (dao *DAO) UpdateScan(ctx context.Context, scan *models.Scan) error { + if scan == nil { + return utils.ErrNilPtr + } + + _, err := dao.Update(ctx, scan) + return err +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Next gets the next scan whose status is `waiting“ based upon the created_at column +func (dao *DAO) NextWaitingScan(ctx context.Context, model models.Modeler) error { + if model == nil { + return utils.ErrNilPtr + } + + options := &database.Options{ + Where: squirrel.Eq{model.Table() + ".status": types.ScanStatusWaiting}, + OrderBy: []string{model.Table() + ".created_at ASC"}, + } + + return dao.Get(ctx, model, options) +} diff --git a/dao/scan_test.go b/dao/scan_test.go new file mode 100644 index 0000000..eb1eb04 --- /dev/null +++ b/dao/scan_test.go @@ -0,0 +1,163 @@ +package dao + +import ( + "database/sql" + "fmt" + "testing" + "time" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CreateScan(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1", Available: true, CardPath: "/course-1/card-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID} + require.Nil(t, dao.CreateScan(ctx, scan)) + }) + + t.Run("nil pointer", func(t *testing.T) { + dao, ctx := setup(t) + + require.ErrorIs(t, dao.CreateScan(ctx, nil), utils.ErrNilPtr) + }) + + t.Run("invalid course ID", func(t *testing.T) { + dao, ctx := setup(t) + + scan := &models.Scan{CourseID: "invalid"} + require.ErrorContains(t, dao.CreateScan(ctx, scan), "FOREIGN KEY constraint failed") + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_Update(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + originalScan := &models.Scan{CourseID: course.ID} + require.NoError(t, dao.CreateScan(ctx, originalScan)) + + newS := &models.Scan{ + Base: originalScan.Base, + CourseID: "1234", // Immutable + Status: types.NewScanStatusProcessing(), // Mutable + } + + time.Sleep(1 * time.Millisecond) + require.NoError(t, dao.UpdateScan(ctx, newS)) + + scanResult := &models.Scan{Base: models.Base{ID: originalScan.ID}} + require.Nil(t, dao.GetById(ctx, scanResult)) + require.Equal(t, originalScan.ID, scanResult.ID) // No change + require.Equal(t, originalScan.CourseID, scanResult.CourseID) // No change + require.True(t, scanResult.CreatedAt.Equal(originalScan.CreatedAt)) // No change + require.False(t, scanResult.Status.IsWaiting()) // Changed + require.False(t, scanResult.UpdatedAt.Equal(originalScan.UpdatedAt)) // Changed + }) + + t.Run("invalid", func(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID} + require.NoError(t, dao.CreateScan(ctx, scan)) + + // Empty ID + scan.ID = "" + require.ErrorIs(t, dao.UpdateScan(ctx, scan), utils.ErrInvalidId) + + // Nil Model + require.ErrorIs(t, dao.UpdateScan(ctx, nil), utils.ErrNilPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_NextWaitingScan(t *testing.T) { + t.Run("first", func(t *testing.T) { + dao, ctx := setup(t) + + scans := []*models.Scan{} + for i := range 3 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID} + require.NoError(t, dao.CreateScan(ctx, scan)) + scans = append(scans, scan) + + time.Sleep(1 * time.Millisecond) + } + + scanResult := &models.Scan{} + require.NoError(t, dao.NextWaitingScan(ctx, scanResult)) + require.Equal(t, scans[0].ID, scanResult.ID) + }) + + t.Run("next", func(t *testing.T) { + dao, ctx := setup(t) + + scans := []*models.Scan{} + for i := range 3 { + course := &models.Course{ + Title: fmt.Sprintf("Course %d", i), + Path: fmt.Sprintf("/course-%d", i), + } + require.NoError(t, dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID} + require.NoError(t, dao.CreateScan(ctx, scan)) + scans = append(scans, scan) + + time.Sleep(1 * time.Millisecond) + } + + scans[0].Status = types.NewScanStatusProcessing() + require.NoError(t, dao.UpdateScan(ctx, scans[0])) + + scanResult := &models.Scan{} + require.NoError(t, dao.NextWaitingScan(ctx, scanResult)) + require.Equal(t, scans[1].ID, scanResult.ID) + }) + + t.Run("empty", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.NextWaitingScan(ctx, &models.Scan{}), sql.ErrNoRows) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_ScanDeleteCascade(t *testing.T) { + dao, ctx := setup(t) + + course := &models.Course{Title: "Course", Path: "/course"} + require.NoError(t, dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID} + require.NoError(t, dao.Create(ctx, scan)) + + require.Nil(t, dao.Delete(ctx, course, nil)) + + err := dao.GetById(ctx, scan) + require.ErrorIs(t, err, sql.ErrNoRows) +} diff --git a/dao/tag.go b/dao/tag.go new file mode 100644 index 0000000..b92f9d4 --- /dev/null +++ b/dao/tag.go @@ -0,0 +1,151 @@ +package dao + +import ( + "context" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateTag creates a tag +func (dao *DAO) CreateTag(ctx context.Context, tag *models.Tag) error { + if tag == nil { + return utils.ErrNilPtr + } + + return dao.Create(ctx, tag) + +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UpdateTag updates a tag +func (dao *DAO) UpdateTag(ctx context.Context, tag *models.Tag) error { + if tag == nil { + return utils.ErrNilPtr + } + + _, err := dao.Update(ctx, tag) + return err +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// // Get gets a tag with the given ID or name +// // +// // CourseTags can be included by setting the `IncludeRelations` field in the dbParams. The coutseTags +// // can then be ordered by setting the `OrderBy` field in the dbParams, specifically referencing +// // courses_tags.[column] +// func (dao *TagDao) Get(id string, byName bool, dbParams *database.DatabaseParams, tx *database.Tx) (*models.Tag, error) { +// selectColumns, _ := tableColumnsOrPanic(models.Tag{}, dao.Table()) + +// tagDbParams := &database.DatabaseParams{ +// Columns: selectColumns, +// } + +// if byName { +// if dbParams != nil && dbParams.CaseInsensitive { +// tagDbParams.Where = squirrel.Eq{dao.Table() + ".tag COLLATE NOCASE": id} +// } else { +// tagDbParams.Where = squirrel.Eq{dao.Table() + ".tag": id} +// } +// } else { +// tagDbParams.Where = squirrel.Eq{dao.Table() + ".id": id} +// } + +// tag, err := genericGet(dao, tagDbParams, dao.scanRow, tx) +// if err != nil { +// return nil, err +// } + +// // Get the course tags +// courseTagDao := NewCourseTagDao(dao.db) +// if dbParams != nil && slices.Contains(dbParams.IncludeRelations, courseTagDao.Table()) { +// _, orderByColumns := tableColumnsOrPanic(models.CourseTag{}, courseTagDao.Table()) + +// courseTagDbParams := &database.DatabaseParams{ +// OrderBy: genericProcessOrderBy(dbParams.OrderBy, orderByColumns, courseTagDao, true), +// Where: squirrel.Eq{"tag_id": id}, +// } + +// // Get the course_tags +// courseTags, err := courseTagDao.List(courseTagDbParams, tx) +// if err != nil { +// return nil, err +// } + +// tag.CourseTags = courseTags +// } + +// return tag, nil +// } + +// // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// // List lists tags +// // +// // CourseTags can be included by setting the `IncludeRelations` field in the dbParams. The coutseTags +// // can then be ordered by setting the `OrderBy` field in the dbParams, specifically referencing +// // courses_tags.[column] +// func (dao *TagDao) List(dbParams *database.DatabaseParams, tx *database.Tx) ([]*models.Tag, error) { +// if dbParams == nil { +// dbParams = &database.DatabaseParams{} +// } + +// selectColumns, orderByColumns := tableColumnsOrPanic(models.Tag{}, dao.Table()) + +// // Backup the original order by then remove invalid orderBy columns +// origOrderBy := dbParams.OrderBy +// dbParams.OrderBy = genericProcessOrderBy(dbParams.OrderBy, orderByColumns, dao, false) + +// // Default the columns if not specified +// if len(dbParams.Columns) == 0 { +// dbParams.Columns = selectColumns +// } + +// tags, err := genericList(dao, dbParams, dao.scanRow, tx) +// if err != nil { +// return nil, err +// } + +// // Get the course_tags +// courseTagDao := NewCourseTagDao(dao.db) +// if len(tags) > 0 && slices.Contains(dbParams.IncludeRelations, courseTagDao.Table()) { +// // Get the tag IDs +// tagIds := []string{} +// for _, t := range tags { +// tagIds = append(tagIds, t.ID) +// } + +// _, orderByColumns := tableColumnsOrPanic(models.CourseTag{}, courseTagDao.Table()) + +// // Reduce the order by clause to only include columns specific to the course_tags table +// reducedOrderBy := genericProcessOrderBy(origOrderBy, orderByColumns, courseTagDao, true) + +// dbParams = &database.DatabaseParams{ +// OrderBy: reducedOrderBy, +// Where: squirrel.Eq{"tag_id": tagIds}, +// } + +// // Get the course_tags +// courseTags, err := courseTagDao.List(dbParams, tx) +// if err != nil { +// return nil, err +// } + +// // Map the course_tags to the tags +// tagMap := map[string][]*models.CourseTag{} +// for _, ct := range courseTags { +// tagMap[ct.TagId] = append(tagMap[ct.TagId], ct) +// } + +// // Assign the course_tags to the tags +// for _, t := range tags { +// t.CourseTags = tagMap[t.ID] +// } +// } + +// return tags, nil +// } diff --git a/dao/tag_test.go b/dao/tag_test.go new file mode 100644 index 0000000..750c5b8 --- /dev/null +++ b/dao/tag_test.go @@ -0,0 +1,80 @@ +package dao + +import ( + "testing" + "time" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CreateTag(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + tag := &models.Tag{Tag: "Tag 1"} + require.NoError(t, dao.CreateTag(ctx, tag)) + }) + + t.Run("nil", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.CreateTag(ctx, nil), utils.ErrNilPtr) + }) + + t.Run("duplicate", func(t *testing.T) { + dao, ctx := setup(t) + + tag := &models.Tag{Base: models.Base{ID: "1"}, Tag: "Tag 1"} + require.NoError(t, dao.CreateTag(ctx, tag)) + + // Duplicate ID + tag = &models.Tag{Base: models.Base{ID: "1"}, Tag: "Tag 2"} + require.ErrorContains(t, dao.CreateTag(ctx, tag), "UNIQUE constraint failed: tags.id") + + // Duplicate tag + tag = &models.Tag{Base: models.Base{ID: "2"}, Tag: "Tag 1"} + require.ErrorContains(t, dao.CreateTag(ctx, tag), "UNIQUE constraint failed: tags.tag") + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_UpdateTag(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + + originalTag := &models.Tag{Tag: "Tag 1"} + require.Nil(t, dao.CreateTag(ctx, originalTag)) + + time.Sleep(1 * time.Millisecond) + + newTag := &models.Tag{ + Base: originalTag.Base, + Tag: "Tag 2", // Mutable + } + require.NoError(t, dao.UpdateTag(ctx, newTag)) + + tagResult := &models.Tag{Base: models.Base{ID: originalTag.ID}} + require.NoError(t, dao.GetById(ctx, tagResult)) + require.Equal(t, originalTag.ID, tagResult.ID) // No change + require.True(t, tagResult.CreatedAt.Equal(originalTag.CreatedAt)) // No change + require.False(t, tagResult.UpdatedAt.Equal(originalTag.UpdatedAt)) // Changed + require.Equal(t, newTag.Tag, tagResult.Tag) // Changed + }) + + t.Run("invalid", func(t *testing.T) { + dao, ctx := setup(t) + + tag := &models.Tag{Tag: "Tag 1"} + require.NoError(t, dao.CreateTag(ctx, tag)) + + // Empty ID + tag.ID = "" + require.ErrorIs(t, dao.UpdateTag(ctx, tag), utils.ErrInvalidId) + + // Nil Model + require.ErrorIs(t, dao.UpdateTag(ctx, nil), utils.ErrNilPtr) + }) +} diff --git a/dao/user.go b/dao/user.go new file mode 100644 index 0000000..dbfdfdc --- /dev/null +++ b/dao/user.go @@ -0,0 +1,19 @@ +package dao + +import ( + "context" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CreateUser creates a user +func (dao *DAO) CreateUser(ctx context.Context, user *models.User) error { + if user == nil { + return utils.ErrNilPtr + } + + return dao.Create(ctx, user) +} diff --git a/dao/user_test.go b/dao/user_test.go new file mode 100644 index 0000000..9e808a4 --- /dev/null +++ b/dao/user_test.go @@ -0,0 +1,25 @@ +package dao + +import ( + "testing" + + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/types" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_CreateUser(t *testing.T) { + t.Run("success", func(t *testing.T) { + dao, ctx := setup(t) + user := &models.User{Username: "admin", PasswordHash: "password", Role: types.UserRoleAdmin} + require.NoError(t, dao.CreateUser(ctx, user)) + }) + + t.Run("nil", func(t *testing.T) { + dao, ctx := setup(t) + require.ErrorIs(t, dao.CreateUser(ctx, nil), utils.ErrNilPtr) + }) +} diff --git a/daos/asset.go b/daos/asset.go deleted file mode 100644 index 42e2b02..0000000 --- a/daos/asset.go +++ /dev/null @@ -1,383 +0,0 @@ -package daos - -import ( - "database/sql" - "slices" - "time" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// AssetDao is the data access object for assets -type AssetDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewAssetDao returns a new AssetDao -func NewAssetDao(db database.Database) *AssetDao { - return &AssetDao{ - db: db, - table: "assets", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *AssetDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Count returns the number of assets -func (dao *AssetDao) Count(params *database.DatabaseParams, tx *database.Tx) (int, error) { - generic := NewGenericDao(dao.db, dao) - return generic.Count(params, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Create inserts a new asset -func (dao *AssetDao) Create(a *models.Asset, tx *database.Tx) error { - if a.Prefix.Valid && a.Prefix.Int16 < 0 { - return ErrInvalidPrefix - } - - if a.ID == "" { - a.RefreshId() - } - - a.RefreshCreatedAt() - a.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(a)). - ToSql() - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - _, err := execFn(query, args...) - - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Get selects an asset with the given ID. -// -// `dbparams` can be used to order the attachments -// -// `tx` allows for the function to be run within a transaction -func (dao *AssetDao) Get(id string, dbParams *database.DatabaseParams, tx *database.Tx) (*models.Asset, error) { - generic := NewGenericDao(dao.db, dao) - - assetDbParams := &database.DatabaseParams{ - Columns: dao.columns(), - Where: squirrel.Eq{dao.Table() + ".id": id}, - } - - row, err := generic.Get(assetDbParams, tx) - if err != nil { - return nil, err - } - - asset, err := dao.scanRow(row) - if err != nil { - return nil, err - } - - // Get the attachments - attachmentDao := NewAttachmentDao(dao.db) - if dbParams != nil && slices.Contains(dbParams.IncludeRelations, attachmentDao.Table()) { - // Set the DB params - attachmentDbParams := &database.DatabaseParams{ - OrderBy: attachmentDao.ProcessOrderBy(dbParams.OrderBy, true), - Where: squirrel.Eq{"asset_id": asset.ID}, - } - - attachments, err := attachmentDao.List(attachmentDbParams, tx) - if err != nil { - return nil, err - } - - asset.Attachments = attachments - } - - return asset, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// List selects assets -// -// `tx` allows for the function to be run within a transaction -func (dao *AssetDao) List(dbParams *database.DatabaseParams, tx *database.Tx) ([]*models.Asset, error) { - generic := NewGenericDao(dao.db, dao) - - if dbParams == nil { - dbParams = &database.DatabaseParams{} - } - - origOrderBy := dbParams.OrderBy - dbParams.OrderBy = dao.ProcessOrderBy(dbParams.OrderBy, false) - - // Default the columns if not specified - if len(dbParams.Columns) == 0 { - dbParams.Columns = dao.columns() - } - - rows, err := generic.List(dbParams, tx) - if err != nil { - return nil, err - } - defer rows.Close() - - var assets []*models.Asset - assetIds := []string{} - - for rows.Next() { - a, err := dao.scanRow(rows) - if err != nil { - return nil, err - } - - assets = append(assets, a) - assetIds = append(assetIds, a.ID) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - // Get the attachments - attachmentDao := NewAttachmentDao(dao.db) - if len(assets) > 0 && slices.Contains(dbParams.IncludeRelations, attachmentDao.Table()) { - - // Reduce the order by clause to only include columns specific to the attachments table - reducedOrderBy := attachmentDao.ProcessOrderBy(origOrderBy, true) - - dbParams = &database.DatabaseParams{ - OrderBy: reducedOrderBy, - Where: squirrel.Eq{"asset_id": assetIds}, - } - - // Get the attachments - attachments, err := attachmentDao.List(dbParams, tx) - if err != nil { - return nil, err - } - - // Store in a map for easy lookup - attachmentsMap := map[string][]*models.Attachment{} - for _, a := range attachments { - attachmentsMap[a.AssetID] = append(attachmentsMap[a.AssetID], a) - } - - // Add attachments to its asset - for _, a := range assets { - a.Attachments = attachmentsMap[a.ID] - } - } - - return assets, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Update updates an asset -// -// Note: Only `title`, `prefix`, `chapter`, `type`, and `path` can be updated -func (dao *AssetDao) Update(asset *models.Asset, tx *database.Tx) error { - if asset.ID == "" { - return ErrEmptyId - } - - asset.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Update(dao.Table()). - Set("title", NilStr(asset.Title)). - Set("prefix", asset.Prefix). - Set("chapter", NilStr(asset.Chapter)). - Set("type", NilStr(asset.Type.String())). - Set("path", NilStr(asset.Path)). - Set("updated_at", FormatTime(asset.UpdatedAt)). - Where("id = ?", asset.ID). - ToSql() - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - _, err := execFn(query, args...) - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Delete deletes an asset based upon the where clause -// -// `tx` allows for the function to be run within a transaction -func (dao *AssetDao) Delete(dbParams *database.DatabaseParams, tx *database.Tx) error { - if dbParams == nil || dbParams.Where == nil { - return ErrMissingWhere - } - - generic := NewGenericDao(dao.db, dao) - return generic.Delete(dbParams, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ProcessOrderBy takes an array of strings representing orderBy clauses and returns a processed -// version of this array -// -// It will creates a new list of valid table columns based upon columns() for the current -// DAO -func (dao *AssetDao) ProcessOrderBy(orderBy []string, explicit bool) []string { - if len(orderBy) == 0 { - return orderBy - } - - generic := NewGenericDao(dao.db, dao) - return generic.ProcessOrderBy(orderBy, dao.columns(), explicit) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default count select builder -func (dao *AssetDao) countSelect() squirrel.SelectBuilder { - apDao := NewAssetProgressDao(dao.db) - - return squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - LeftJoin(apDao.Table() + " ON " + dao.Table() + ".id = " + apDao.Table() + ".asset_id"). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// baseSelect returns the default select builder -// -// It performs 1 left join -// - assets progress table to get `video_pos`, `completed` and `completed_at` -// -// Note: The columns are removed, so you must specify the columns with `.Columns(...)` when using -// this select builder -func (dao *AssetDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *AssetDao) columns() []string { - apDao := NewAssetProgressDao(dao.db) - - return []string{ - dao.Table() + ".*", - apDao.Table() + ".video_pos", - apDao.Table() + ".completed", - apDao.Table() + ".completed_at", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for an asset -func (dao *AssetDao) data(a *models.Asset) map[string]any { - return map[string]any{ - "id": a.ID, - "course_id": NilStr(a.CourseID), - "title": NilStr(a.Title), - "prefix": a.Prefix, - "chapter": NilStr(a.Chapter), - "type": NilStr(a.Type.String()), - "path": NilStr(a.Path), - "hash": NilStr(a.Hash), - "created_at": FormatTime(a.CreatedAt), - "updated_at": FormatTime(a.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans an asset row -func (dao *AssetDao) scanRow(scannable Scannable) (*models.Asset, error) { - var a models.Asset - - // Nullable fields - var chapter sql.NullString - var videoPos sql.NullInt16 - var completed sql.NullBool - var createdAt string - var updatedAt string - var completedAt sql.NullString - - err := scannable.Scan( - &a.ID, - &a.CourseID, - &a.Title, - &a.Prefix, - &chapter, - &a.Type, - &a.Path, - &a.Hash, - &createdAt, - &updatedAt, - - // Asset progress - &videoPos, - &completed, - &completedAt, - ) - - if err != nil { - return nil, err - } - - if chapter.Valid { - a.Chapter = chapter.String - } - - if a.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if a.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - if completedAt.Valid { - if a.CompletedAt, err = ParseTime(completedAt.String); err != nil { - return nil, err - } - } else { - a.CompletedAt = time.Time{} - } - - a.VideoPos = int(videoPos.Int16) - a.Completed = completed.Bool - - return &a, nil -} diff --git a/daos/asset_progress.go b/daos/asset_progress.go deleted file mode 100644 index b3ebf46..0000000 --- a/daos/asset_progress.go +++ /dev/null @@ -1,289 +0,0 @@ -package daos - -import ( - "time" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// AssetProgressDao is the data access object for assets progress -type AssetProgressDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewAssetProgressDao returns a new AssetProgressDao -func NewAssetProgressDao(db database.Database) *AssetProgressDao { - return &AssetProgressDao{ - db: db, - table: "assets_progress", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *AssetProgressDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Create inserts a new asset progress, then refreshes the course progress -// -// If `tx` is nil, the function will create a new transaction, else it will use the current -// transaction -func (dao *AssetProgressDao) Create(ap *models.AssetProgress, tx *database.Tx) error { - if tx == nil { - return dao.db.RunInTransaction(func(tx *database.Tx) error { - return dao.create(ap, tx) - }) - } else { - return dao.create(ap, tx) - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Get selects an asset progress with the given asset ID -// -// `tx` allows for the function to be run within a transaction -func (dao *AssetProgressDao) Get(assetId string, tx *database.Tx) (*models.AssetProgress, error) { - generic := NewGenericDao(dao.db, dao) - - dbParams := &database.DatabaseParams{ - Columns: dao.columns(), - Where: squirrel.Eq{dao.Table() + ".asset_id": assetId}, - } - - row, err := generic.Get(dbParams, tx) - if err != nil { - return nil, err - } - - cp, err := dao.scanRow(row) - if err != nil { - return nil, err - } - - return cp, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Update updates the `video_pos` (for video assets) and `completed`, then refreshes the course progress -// -// When `completed` is true, `completed_at` is set to the current time. When the `completed` -// is false, `completed_at` is set to null -// -// If `tx` is nil, the function will create a new transaction, else it will use the current -// transaction -func (dao *AssetProgressDao) Update(ap *models.AssetProgress, tx *database.Tx) error { - if tx == nil { - return dao.db.RunInTransaction(func(tx *database.Tx) error { - return dao.update(ap, tx) - }) - } else { - return dao.update(ap, tx) - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// create inserts a new asset progress, then refreshes the course progress -// -// This function is used by Create() and always runs within a transaction -func (dao *AssetProgressDao) create(ap *models.AssetProgress, tx *database.Tx) error { - if ap.ID == "" { - ap.RefreshId() - } - - ap.RefreshCreatedAt() - ap.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(ap)). - ToSql() - - _, err := tx.Exec(query, args...) - - if err != nil { - return err - } - - // Refresh course progress - cpDao := NewCourseProgressDao(dao.db) - return cpDao.Refresh(ap.CourseID, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// update updates the asset progress, then refreshes the course progress -// -// This function is used by Update() and always runs within a transaction -func (dao *AssetProgressDao) update(ap *models.AssetProgress, tx *database.Tx) error { - if ap.AssetID == "" { - return ErrEmptyId - } - - if tx == nil { - return ErrNilTransaction - } - - // Normalize video position - if ap.VideoPos < 0 { - ap.VideoPos = 0 - } - - // Get the asset - assetDao := NewAssetDao(dao.db) - asset, err := assetDao.Get(ap.AssetID, nil, tx) - if err != nil { - return err - } - - // Return when nothing has changed - if asset.VideoPos == ap.VideoPos && asset.Completed == ap.Completed { - return nil - } - - // Set an id (if empty) - if ap.ID == "" { - ap.RefreshId() - } - - // Set course id (if empty) - if ap.CourseID == "" { - ap.CourseID = asset.CourseID - } - - ap.RefreshCreatedAt() - ap.RefreshUpdatedAt() - - if ap.Completed { - if !asset.CompletedAt.IsZero() { - ap.CompletedAt = asset.CompletedAt - } else { - ap.CompletedAt = time.Now() - } - } else { - ap.CompletedAt = time.Time{} - } - - // Update (or create if it doesn't exist) - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(ap)). - Suffix( - "ON CONFLICT (asset_id) DO UPDATE SET video_pos = ?, completed = ?, completed_at = ?, updated_at = ?", - ap.VideoPos, ap.Completed, FormatTime(ap.CompletedAt), FormatTime(ap.UpdatedAt), - ). - ToSql() - - _, err = tx.Exec(query, args...) - if err != nil { - return err - } - - // Refresh course progress - cpDao := NewCourseProgressDao(dao.db) - return cpDao.Refresh(ap.CourseID, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default count select builder -func (dao *AssetProgressDao) countSelect() squirrel.SelectBuilder { - return squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// baseSelect returns the default select builder -// -// Note: The columns are removed, so you must specify the columns with `.Columns(...)` when using -// this select builder -func (dao *AssetProgressDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *AssetProgressDao) columns() []string { - return []string{ - dao.Table() + ".*", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for thane asset progress -func (dao *AssetProgressDao) data(ap *models.AssetProgress) map[string]any { - return map[string]any{ - "id": ap.ID, - "asset_id": NilStr(ap.AssetID), - "course_id": NilStr(ap.CourseID), - "video_pos": ap.VideoPos, - "completed": ap.Completed, - "completed_at": FormatTime(ap.CompletedAt), - "created_at": FormatTime(ap.CreatedAt), - "updated_at": FormatTime(ap.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans an assets progress row -func (dao *AssetProgressDao) scanRow(scannable Scannable) (*models.AssetProgress, error) { - var ap models.AssetProgress - - var createdAt string - var updatedAt string - var completedAt string - - err := scannable.Scan( - &ap.ID, - &ap.AssetID, - &ap.CourseID, - &ap.VideoPos, - &ap.Completed, - &completedAt, - &createdAt, - &updatedAt, - ) - - if err != nil { - return nil, err - } - - if ap.CompletedAt, err = ParseTime(completedAt); err != nil { - return nil, err - } - - if ap.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if ap.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - return &ap, nil -} diff --git a/daos/asset_progress_test.go b/daos/asset_progress_test.go deleted file mode 100644 index 067977e..0000000 --- a/daos/asset_progress_test.go +++ /dev/null @@ -1,320 +0,0 @@ -package daos - -import ( - "database/sql" - "fmt" - "testing" - - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func assetProgressSetup(t *testing.T) (*AssetProgressDao, database.Database) { - t.Helper() - - dbManager := setup(t) - apDao := NewAssetProgressDao(dbManager.DataDb) - return apDao, dbManager.DataDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAssetProgress_Create(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - ap := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - - err := dao.Create(ap, nil) - require.Nil(t, err) - require.NotEmpty(t, ap.ID) - require.Equal(t, testData[0].Assets[0].ID, ap.AssetID) - require.Zero(t, ap.VideoPos) - require.False(t, ap.Completed) - require.True(t, ap.CompletedAt.IsZero()) - require.False(t, ap.CreatedAt.IsZero()) - require.False(t, ap.UpdatedAt.IsZero()) - }) - - t.Run("duplicate asset id", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - ap := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - - // Create - require.Nil(t, dao.Create(ap, nil)) - - // Create (again) - require.ErrorContains(t, dao.Create(ap, nil), fmt.Sprintf("UNIQUE constraint failed: %s.asset_id", dao.Table())) - }) - - t.Run("constraint errors", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - ap := &models.AssetProgress{} - - // Asset ID - require.ErrorContains(t, dao.Create(ap, nil), fmt.Sprintf("NOT NULL constraint failed: %s.asset_id", dao.Table())) - ap.AssetID = "" - - require.ErrorContains(t, dao.Create(ap, nil), fmt.Sprintf("NOT NULL constraint failed: %s.asset_id", dao.Table())) - ap.AssetID = "1234" - - // Course ID - require.ErrorContains(t, dao.Create(ap, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - ap.CourseID = "" - - require.ErrorContains(t, dao.Create(ap, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - ap.CourseID = "1234" - - // Invalid asset ID - require.ErrorContains(t, dao.Create(ap, nil), "FOREIGN KEY constraint failed") - ap.AssetID = testData[0].Assets[0].ID - - // Invalid course ID - require.ErrorContains(t, dao.Create(ap, nil), "FOREIGN KEY constraint failed") - ap.CourseID = testData[0].ID - - // Success - require.Nil(t, dao.Create(ap, nil)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAssetProgress_Get(t *testing.T) { - t.Run("found", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Assets(1).Build() - - for _, tc := range testData { - require.Nil(t, dao.Create(&models.AssetProgress{ - AssetID: tc.Assets[0].ID, - CourseID: tc.ID, - }, nil)) - } - - ap, err := dao.Get(testData[1].Assets[0].ID, nil) - require.Nil(t, err) - require.Equal(t, testData[1].Assets[0].ID, ap.AssetID) - }) - - t.Run("not found", func(t *testing.T) { - dao, _ := assetProgressSetup(t) - - ap, err := dao.Get("1234", nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, ap) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := assetProgressSetup(t) - - ap, err := dao.Get("", nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, ap) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Get("1234", nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAssetProgress_Update(t *testing.T) { - t.Run("update", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // Create the asset progress - originalAp := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - require.Nil(t, dao.Create(originalAp, nil)) - - require.Zero(t, originalAp.VideoPos) - - cpDao := NewCourseProgressDao(db) - - // ---------------------------- - // Set to 50 - // ---------------------------- - originalAp.VideoPos = 50 - require.Nil(t, dao.Update(originalAp, nil)) - - updatedAp1, err := dao.Get(originalAp.AssetID, nil) - require.Nil(t, err) - require.Equal(t, 50, updatedAp1.VideoPos) - - // Ensure the course was set to started - cp1, err := cpDao.Get(testData[0].ID, nil) - require.Nil(t, err) - require.True(t, cp1.Started) - require.False(t, cp1.StartedAt.IsZero()) - - // ---------------------------- - // Set to -10 (should be set to 0) - // ---------------------------- - updatedAp1.VideoPos = -10 - require.Nil(t, dao.Update(updatedAp1, nil)) - - updatedAp2, err := dao.Get(updatedAp1.AssetID, nil) - require.Nil(t, err) - require.Zero(t, updatedAp2.VideoPos) - - // Ensure the course is not started - cp2, err := cpDao.Get(testData[0].ID, nil) - require.Nil(t, err) - require.False(t, cp2.Started) - require.True(t, cp2.StartedAt.IsZero()) - - // ---------------------------- - // Set completed - // ---------------------------- - updatedAp2.Completed = true - require.Nil(t, dao.Update(updatedAp2, nil)) - - updatedAp3, err := dao.Get(updatedAp2.AssetID, nil) - require.Nil(t, err) - require.Zero(t, updatedAp3.VideoPos) - require.True(t, updatedAp3.Completed) - require.False(t, updatedAp3.CompletedAt.IsZero()) - - // Ensure the course is started and completed - cp3, err := cpDao.Get(testData[0].ID, nil) - require.Nil(t, err) - require.True(t, cp3.Started) - require.False(t, cp3.StartedAt.IsZero()) - require.Equal(t, 100, cp3.Percent) - require.False(t, cp3.CompletedAt.IsZero()) - }) - - t.Run("empty id", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // Create the asset progress - ap := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - - require.Nil(t, dao.Create(ap, nil)) - - ap.AssetID = "" - - require.EqualError(t, dao.Update(ap, nil), "id cannot be empty") - }) - - t.Run("invalid asset id", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // Create the asset progress - ap := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - require.Nil(t, dao.Create(ap, nil)) - - ap.AssetID = "1234" - - require.ErrorIs(t, dao.Update(ap, nil), sql.ErrNoRows) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // Create the asset progress - ap := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - - require.Nil(t, dao.Create(ap, nil)) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - require.ErrorContains(t, dao.Update(ap, nil), "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAssetProgress_DeleteCascade(t *testing.T) { - t.Run("course", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // Create the asset progress - ap := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - - require.Nil(t, dao.Create(ap, nil)) - - // Delete the course - courseDao := NewCourseDao(db) - err := courseDao.Delete(&database.DatabaseParams{Where: map[string]interface{}{"id": testData[0].ID}}, nil) - require.Nil(t, err) - - // Check the asset progress was deleted - _, err = dao.Get(ap.ID, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - }) - - t.Run("asset", func(t *testing.T) { - dao, db := assetProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // Create the asset progress - ap := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - - require.Nil(t, dao.Create(ap, nil)) - - // Delete the asset - assetDao := NewAssetDao(db) - err := assetDao.Delete(&database.DatabaseParams{Where: map[string]interface{}{"id": testData[0].Assets[0].ID}}, nil) - require.Nil(t, err) - - // Check the asset progress was deleted - _, err = dao.Get(ap.ID, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - }) -} diff --git a/daos/asset_test.go b/daos/asset_test.go deleted file mode 100644 index fc785d9..0000000 --- a/daos/asset_test.go +++ /dev/null @@ -1,630 +0,0 @@ -package daos - -import ( - "database/sql" - "fmt" - "math/rand" - "testing" - "time" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/pagination" - "github.com/geerew/off-course/utils/security" - "github.com/geerew/off-course/utils/types" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func assetSetup(t *testing.T) (*AssetDao, database.Database) { - t.Helper() - - dbManager := setup(t) - assetDao := NewAssetDao(dbManager.DataDb) - return assetDao, dbManager.DataDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAsset_Count(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := assetSetup(t) - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Zero(t, count) - }) - - t.Run("entries", func(t *testing.T) { - dao, db := assetSetup(t) - - NewTestBuilder(t).Db(db).Courses(5).Assets(1).Build() - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Equal(t, count, 5) - }) - - t.Run("where", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Assets(2).Build() - - // ---------------------------- - // EQUALS ID - // ---------------------------- - count, err := dao.Count(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".id": testData[0].Assets[1].ID}}, nil) - require.Nil(t, err) - require.Equal(t, 1, count) - - // ---------------------------- - // NOT EQUALS ID - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.NotEq{dao.Table() + ".id": testData[0].Assets[1].ID}}, nil) - require.Nil(t, err) - require.Equal(t, 5, count) - - // ---------------------------- - // EQUALS COURSE_ID - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".course_id": testData[1].ID}}, nil) - require.Nil(t, err) - require.Equal(t, 2, count) - - // ---------------------------- - // ERROR - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Zero(t, count) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := assetSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Count(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAsset_Create(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Courses(1).Assets(1).Build() - - // Create the course - courseDao := NewCourseDao(db) - require.Nil(t, courseDao.Create(testData[0].Course)) - - // Create the asset - err := dao.Create(testData[0].Assets[0], nil) - require.Nil(t, err) - - newA, err := dao.Get(testData[0].Assets[0].ID, nil, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Assets[0].ID, newA.ID) - require.Equal(t, testData[0].Assets[0].CourseID, newA.CourseID) - require.Equal(t, testData[0].Assets[0].Title, newA.Title) - require.Equal(t, testData[0].Assets[0].Prefix, newA.Prefix) - require.Equal(t, testData[0].Assets[0].Chapter, newA.Chapter) - require.Equal(t, testData[0].Assets[0].Type, newA.Type) - require.Equal(t, testData[0].Assets[0].Path, newA.Path) - require.False(t, newA.CreatedAt.IsZero()) - require.False(t, newA.UpdatedAt.IsZero()) - - // Progress - require.Zero(t, newA.VideoPos) - require.False(t, newA.Completed) - require.True(t, newA.CompletedAt.IsZero()) - }) - - t.Run("duplicate paths", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // Create the asset (again) - err := dao.Create(testData[0].Assets[0], nil) - require.ErrorContains(t, err, fmt.Sprintf("UNIQUE constraint failed: %s.path", dao.Table())) - }) - - t.Run("constraints", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - // No course ID - asset := &models.Asset{} - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - asset.CourseID = "" - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - asset.CourseID = "1234" - - // No title - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.title", dao.Table())) - asset.Title = "" - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.title", dao.Table())) - asset.Title = "Course 1" - - // No/invalid prefix - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.prefix", dao.Table())) - asset.Prefix = sql.NullInt16{Int16: -1, Valid: true} - require.ErrorContains(t, dao.Create(asset, nil), "prefix must be greater than 0") - asset.Prefix = sql.NullInt16{Int16: 1, Valid: true} - - // No type - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.type", dao.Table())) - asset.Type = types.Asset{} - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.type", dao.Table())) - asset.Type = *types.NewAsset("mp4") - - // No path - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.path", dao.Table())) - asset.Path = "" - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.path", dao.Table())) - asset.Path = "/course 1/01 asset" - - // No hash - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.hash", dao.Table())) - asset.Hash = "" - require.ErrorContains(t, dao.Create(asset, nil), fmt.Sprintf("NOT NULL constraint failed: %s.hash", dao.Table())) - asset.Hash = "1234" - - // Invalid Course ID - require.ErrorContains(t, dao.Create(asset, nil), "FOREIGN KEY constraint failed") - - // Success - asset.CourseID = testData[0].ID - require.Nil(t, dao.Create(asset, nil)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAsset_Get(t *testing.T) { - t.Run("found", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(2).Assets(1).Attachments(2).Build() - - a, err := dao.Get(testData[0].Assets[0].ID, nil, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Assets[0].ID, a.ID) - require.Nil(t, a.Attachments) - - // ---------------------------- - // Progress - // ---------------------------- - apDao := NewAssetProgressDao(db) - - require.Zero(t, a.VideoPos) - require.False(t, a.Completed) - require.True(t, a.CompletedAt.IsZero()) - - // Set video pos - ap := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - VideoPos: 50, - } - - require.Nil(t, apDao.Update(ap, nil)) - - a, err = dao.Get(a.ID, nil, nil) - require.Nil(t, err) - require.Equal(t, 50, a.VideoPos) - require.False(t, a.Completed) - require.True(t, a.CompletedAt.IsZero()) - - // Set completed - ap.Completed = true - require.Nil(t, apDao.Update(ap, nil)) - - a, err = dao.Get(a.ID, nil, nil) - require.Nil(t, err) - require.Equal(t, 50, a.VideoPos) - require.True(t, a.Completed) - require.False(t, a.CompletedAt.IsZero()) - - // ---------------------------- - // Attachments - // ---------------------------- - a, err = dao.Get(testData[0].Assets[0].ID, &database.DatabaseParams{IncludeRelations: []string{NewAttachmentDao(dao.db).Table()}}, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Assets[0].ID, a.ID) - - require.Len(t, a.Attachments, 2) - require.Equal(t, testData[0].Assets[0].Attachments[0].ID, a.Attachments[0].ID) - require.Equal(t, testData[0].Assets[0].Attachments[1].ID, a.Attachments[1].ID) - }) - - t.Run("orderby", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Assets(1).Attachments(2).Build() - - attDao := NewAttachmentDao(db) - - // ---------------------------- - // ATTACHMENTS.CREATED_AT DESC - // ---------------------------- - dbParams := &database.DatabaseParams{ - OrderBy: []string{attDao.Table() + ".created_at desc"}, - IncludeRelations: []string{attDao.Table()}, - } - - result, err := dao.Get(testData[0].Assets[0].ID, dbParams, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Assets[0].ID, result.ID) - require.Equal(t, testData[0].Assets[0].Attachments[0].ID, result.Attachments[1].ID) - require.Equal(t, testData[0].Assets[0].Attachments[1].ID, result.Attachments[0].ID) - - // ---------------------------- - // ATTACHMENTS.CREATED_AT ASC - // ---------------------------- - dbParams = &database.DatabaseParams{ - OrderBy: []string{attDao.Table() + ".created_at asc"}, - IncludeRelations: []string{attDao.Table()}, - } - - result, err = dao.Get(testData[0].Assets[0].ID, dbParams, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Assets[0].ID, result.ID) - require.Equal(t, testData[0].Assets[0].Attachments[0].ID, result.Attachments[0].ID) - require.Equal(t, testData[0].Assets[0].Attachments[1].ID, result.Attachments[1].ID) - - // ---------------------------- - // Error - // ---------------------------- - dbParams = &database.DatabaseParams{ - OrderBy: []string{attDao.Table() + ".unit_test desc"}, - IncludeRelations: []string{attDao.Table()}, - } - - result, err = dao.Get(testData[0].Assets[0].ID, dbParams, nil) - require.ErrorContains(t, err, "no such column") - require.Nil(t, result) - }) - - t.Run("not found", func(t *testing.T) { - dao, _ := assetSetup(t) - - c, err := dao.Get("1234", nil, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, c) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := assetSetup(t) - - c, err := dao.Get("", nil, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, c) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := assetSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Get("1234", nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAsset_List(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := assetSetup(t) - - assets, err := dao.List(nil, nil) - require.Nil(t, err) - require.Zero(t, assets) - }) - - t.Run("found", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(5).Assets(2).Attachments(3).Build() - - result, err := dao.List(nil, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Nil(t, result[0].Attachments) - - // ---------------------------- - // Progress - // ---------------------------- - apDao := NewAssetProgressDao(db) - - for _, a := range result { - require.Zero(t, a.VideoPos) - require.False(t, a.Completed) - require.True(t, a.CompletedAt.IsZero()) - } - - // Update video position for the first asset (This will create the asset progress) - ap1 := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - VideoPos: 50, - } - - require.Nil(t, apDao.Update(ap1, nil)) - // Find all started videos - dbParams := &database.DatabaseParams{ - Where: squirrel.And{ - squirrel.Eq{dao.Table() + ".type": string(types.AssetVideo)}, - squirrel.Gt{apDao.Table() + ".video_pos": 0}, - }, - } - result, err = dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, testData[0].Assets[0].ID, result[0].ID) - require.Equal(t, 50, result[0].VideoPos) - - // Mark the second asset as completed - ap2 := &models.AssetProgress{ - AssetID: testData[1].Assets[1].ID, - CourseID: testData[1].ID, - Completed: true, - CompletedAt: time.Now(), - } - - require.Nil(t, apDao.Update(ap2, nil)) - - // Find completed assets - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{apDao.Table() + ".completed": true}}, nil) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, testData[1].Assets[1].ID, result[0].ID) - require.True(t, result[0].Completed) - require.False(t, result[0].CompletedAt.IsZero()) - - // ---------------------------- - // Attachments - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{IncludeRelations: []string{NewAttachmentDao(dao.db).Table()}}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - - for _, a := range result { - require.Len(t, a.Attachments, 3) - } - }) - - t.Run("orderby", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Assets(1).Attachments(2).Build() - - // ---------------------------- - // CREATED_AT DESC - // ---------------------------- - dbParams := &database.DatabaseParams{OrderBy: []string{"created_at desc"}} - result, err := dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 3) - require.Equal(t, testData[2].Assets[0].ID, result[0].ID) - - // ---------------------------- - // CREATED_AT ASC - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{"created_at asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 3) - require.Equal(t, testData[0].Assets[0].ID, result[0].ID) - - // ---------------------------- - // CREATED_AT ASC + ATTACHMENTS.CREATED_AT DESC - // ---------------------------- - attachmentsDao := NewAttachmentDao(db) - - result, err = dao.List(&database.DatabaseParams{ - OrderBy: []string{ - dao.Table() + ".created_at asc", - attachmentsDao.Table() + ".created_at desc", - }, - IncludeRelations: []string{attachmentsDao.Table()}, - }, nil) - - require.Nil(t, err) - require.Len(t, result, 3) - require.Equal(t, testData[0].Assets[0].ID, result[0].ID) - require.Equal(t, testData[0].Assets[0].Attachments[1].ID, result[0].Attachments[0].ID) - - // ---------------------------- - // Error - // ---------------------------- - dbParams = &database.DatabaseParams{OrderBy: []string{"unit_test asc"}} - result, err = dao.List(dbParams, nil) - require.ErrorContains(t, err, "no such column") - require.Nil(t, result) - }) - - t.Run("where", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Assets(2).Build() - - // ---------------------------- - // EQUALS ID - // ---------------------------- - result, err := dao.List(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".id": testData[0].Assets[1].ID}}, nil) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, testData[0].Assets[1].ID, result[0].ID) - - // ---------------------------- - // EQUALS ID OR ID - // ---------------------------- - dbParams := &database.DatabaseParams{ - Where: squirrel.Or{squirrel.Eq{dao.Table() + ".id": testData[0].Assets[1].ID}, squirrel.Eq{dao.Table() + ".id": testData[1].Assets[1].ID}}, - OrderBy: []string{"created_at asc"}, - } - result, err = dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 2) - require.Equal(t, testData[0].Assets[1].ID, result[0].ID) - require.Equal(t, testData[1].Assets[1].ID, result[1].ID) - - // ---------------------------- - // ERROR - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Nil(t, result) - }) - - t.Run("pagination", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(17).Build() - - // ---------------------------- - // Page 1 with 10 items - // ---------------------------- - p := pagination.New(1, 10) - - result, err := dao.List(&database.DatabaseParams{Pagination: p}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, 17, p.TotalItems()) - require.Equal(t, testData[0].Assets[0].ID, result[0].ID) - require.Equal(t, testData[0].Assets[9].ID, result[9].ID) - - // ---------------------------- - // Page 2 with 7 items - // ---------------------------- - p = pagination.New(2, 10) - - result, err = dao.List(&database.DatabaseParams{Pagination: p}, nil) - require.Nil(t, err) - require.Len(t, result, 7) - require.Equal(t, 17, p.TotalItems()) - require.Equal(t, testData[0].Assets[10].ID, result[0].ID) - require.Equal(t, testData[0].Assets[16].ID, result[6].ID) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := assetSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.List(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAsset_Update(t *testing.T) { - t.Run("update", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - testData[0].Assets[0].Title = security.PseudorandomString(6) - testData[0].Assets[0].Prefix = sql.NullInt16{Int16: int16(rand.Intn(200-1) + 1), Valid: true} - testData[0].Assets[0].Chapter = fmt.Sprintf("%s chapter %s", security.PseudorandomString(3), security.PseudorandomString(2)) - testData[0].Assets[0].Type = *types.NewAsset("html") - testData[0].Assets[0].Path = fmt.Sprintf("%s/%s/%d %s.mp4", testData[0].Path, testData[0].Assets[0].Chapter, testData[0].Assets[0].Prefix.Int16, testData[0].Assets[0].Title) - - require.Nil(t, dao.Update(testData[0].Assets[0], nil)) - - a, err := dao.Get(testData[0].Assets[0].ID, nil, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Assets[0].Prefix, a.Prefix) - require.Equal(t, testData[0].Assets[0].Title, a.Title) - require.Equal(t, testData[0].Assets[0].Chapter, a.Chapter) - require.Equal(t, testData[0].Assets[0].Path, a.Path) - require.True(t, a.Type.IsHTML()) - - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := assetSetup(t) - - err := dao.Update(&models.Asset{}, nil) - require.ErrorIs(t, err, ErrEmptyId) - }) - - t.Run("invalid id", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - testData[0].Assets[0].ID = "1234" - require.Nil(t, dao.Update(testData[0].Assets[0], nil)) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Update(testData[0].Assets[0], nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAsset_Delete(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - err := dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].Assets[0].ID}}, nil) - require.Nil(t, err) - }) - - t.Run("no db params", func(t *testing.T) { - dao, _ := assetSetup(t) - - err := dao.Delete(nil, nil) - require.ErrorIs(t, err, ErrMissingWhere) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := assetSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": "1234"}}, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAsset_DeleteCascade(t *testing.T) { - dao, db := assetSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // Delete the course - courseDao := NewCourseDao(db) - err := courseDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].ID}}, nil) - require.Nil(t, err) - - // Check the asset was deleted - a, err := dao.Get(testData[0].Assets[0].ID, nil, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, a) -} diff --git a/daos/attachment.go b/daos/attachment.go deleted file mode 100644 index 99318a3..0000000 --- a/daos/attachment.go +++ /dev/null @@ -1,249 +0,0 @@ -package daos - -import ( - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// AttachmentDao is the data access object for attachments -type AttachmentDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewAttachmentDao returns a new AttachmentDao -func NewAttachmentDao(db database.Database) *AttachmentDao { - return &AttachmentDao{ - db: db, - table: "attachments", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *AttachmentDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Count returns the number of attachments -func (dao *AttachmentDao) Count(params *database.DatabaseParams, tx *database.Tx) (int, error) { - generic := NewGenericDao(dao.db, dao) - return generic.Count(params, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Create inserts a new attachment -func (dao *AttachmentDao) Create(a *models.Attachment, tx *database.Tx) error { - if a.ID == "" { - a.RefreshId() - } - - a.RefreshCreatedAt() - a.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(a)). - ToSql() - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - _, err := execFn(query, args...) - - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Get selects an attachment with the given ID -// -// `tx` allows for the function to be run within a transaction -func (dao *AttachmentDao) Get(id string, tx *database.Tx) (*models.Attachment, error) { - generic := NewGenericDao(dao.db, dao) - - dbParams := &database.DatabaseParams{ - Columns: dao.columns(), - Where: squirrel.Eq{dao.Table() + ".id": id}, - } - - row, err := generic.Get(dbParams, tx) - if err != nil { - return nil, err - } - - attachment, err := dao.scanRow(row) - if err != nil { - return nil, err - } - - return attachment, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// List selects attachments -// -// `tx` allows for the function to be run within a transaction -func (dao *AttachmentDao) List(dbParams *database.DatabaseParams, tx *database.Tx) ([]*models.Attachment, error) { - generic := NewGenericDao(dao.db, dao) - - if dbParams == nil { - dbParams = &database.DatabaseParams{} - } - - // Process the order by clauses - dbParams.OrderBy = dao.ProcessOrderBy(dbParams.OrderBy, false) - - // Default the columns if not specified - if len(dbParams.Columns) == 0 { - dbParams.Columns = dao.columns() - } - - rows, err := generic.List(dbParams, tx) - if err != nil { - return nil, err - } - defer rows.Close() - - var attachments []*models.Attachment - - for rows.Next() { - a, err := dao.scanRow(rows) - if err != nil { - return nil, err - } - - attachments = append(attachments, a) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return attachments, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Delete deletes an attachment based upon the where clause -// -// `tx` allows for the function to be run within a transaction -func (dao *AttachmentDao) Delete(dbParams *database.DatabaseParams, tx *database.Tx) error { - if dbParams == nil || dbParams.Where == nil { - return ErrMissingWhere - } - - generic := NewGenericDao(dao.db, dao) - return generic.Delete(dbParams, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ProcessOrderBy takes an array of strings representing orderBy clauses and returns a processed -// version of this array -// -// It will creates a new list of valid table columns based upon columns() for the current -// DAO -func (dao *AttachmentDao) ProcessOrderBy(orderBy []string, explicit bool) []string { - if len(orderBy) == 0 { - return orderBy - } - - generic := NewGenericDao(dao.db, dao) - return generic.ProcessOrderBy(orderBy, dao.columns(), explicit) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default count select builder -func (dao *AttachmentDao) countSelect() squirrel.SelectBuilder { - return squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// baseSelect returns the default select builder -// -// Note: The columns are removed, so you must specify the columns with `.Columns(...)` when using -// this select builder -func (dao *AttachmentDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *AttachmentDao) columns() []string { - return []string{ - dao.Table() + ".*", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for an attachment -func (dao *AttachmentDao) data(a *models.Attachment) map[string]any { - return map[string]any{ - "id": a.ID, - "course_id": NilStr(a.CourseID), - "asset_id": NilStr(a.AssetID), - "title": NilStr(a.Title), - "path": NilStr(a.Path), - "created_at": FormatTime(a.CreatedAt), - "updated_at": FormatTime(a.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans an attachment row -func (dao *AttachmentDao) scanRow(scannable Scannable) (*models.Attachment, error) { - var a models.Attachment - - var createdAt string - var updatedAt string - - err := scannable.Scan( - &a.ID, - &a.CourseID, - &a.AssetID, - &a.Title, - &a.Path, - &createdAt, - &updatedAt, - ) - - if err != nil { - return nil, err - } - - if a.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if a.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - return &a, nil -} diff --git a/daos/attachments_test.go b/daos/attachments_test.go deleted file mode 100644 index 875a89d..0000000 --- a/daos/attachments_test.go +++ /dev/null @@ -1,411 +0,0 @@ -package daos - -import ( - "database/sql" - "fmt" - "testing" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/pagination" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func attachmentSetup(t *testing.T) (*AttachmentDao, database.Database) { - t.Helper() - - dbManager := setup(t) - attachmentDao := NewAttachmentDao(dbManager.DataDb) - return attachmentDao, dbManager.DataDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAttachment_Count(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := attachmentSetup(t) - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Zero(t, count) - }) - - t.Run("entries", func(t *testing.T) { - dao, db := attachmentSetup(t) - - NewTestBuilder(t).Db(db).Courses(5).Assets(1).Attachments(1).Build() - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Equal(t, count, 5) - }) - - t.Run("where", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Assets(1).Attachments(2).Build() - - // ---------------------------- - // EQUALS ID - // ---------------------------- - count, err := dao.Count(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".id": testData[1].Assets[0].Attachments[1].ID}}, nil) - require.Nil(t, err) - require.Equal(t, 1, count) - - // ---------------------------- - // NOT EQUALS ID - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.NotEq{dao.Table() + ".id": testData[1].Assets[0].Attachments[1].ID}}, nil) - require.Nil(t, err) - require.Equal(t, 5, count) - - // ---------------------------- - // EQUALS ASSET_ID - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".asset_id": testData[1].Assets[0].ID}}, nil) - require.Nil(t, err) - require.Equal(t, 2, count) - - // ---------------------------- - // ERROR - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Zero(t, count) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := attachmentSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Count(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAttachment_Create(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Courses(1).Assets(1).Attachments(1).Build() - - // Create the course - courseDao := NewCourseDao(db) - require.Nil(t, courseDao.Create(testData[0].Course)) - - // Create the asset - assetDao := NewAssetDao(db) - err := assetDao.Create(testData[0].Assets[0], nil) - require.Nil(t, err) - - // Create the attachment - err = dao.Create(testData[0].Assets[0].Attachments[0], nil) - require.Nil(t, err) - - newA, err := dao.Get(testData[0].Assets[0].Attachments[0].ID, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Assets[0].Attachments[0].ID, newA.ID) - require.Equal(t, testData[0].Assets[0].Attachments[0].CourseID, newA.CourseID) - require.Equal(t, testData[0].Assets[0].Attachments[0].AssetID, newA.AssetID) - require.Equal(t, testData[0].Assets[0].Attachments[0].Title, newA.Title) - require.Equal(t, testData[0].Assets[0].Attachments[0].Path, newA.Path) - require.False(t, newA.CreatedAt.IsZero()) - require.False(t, newA.UpdatedAt.IsZero()) - }) - - t.Run("duplicate paths", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Attachments(1).Build() - - // Create the attachment (again) - err := dao.Create(testData[0].Assets[0].Attachments[0], nil) - require.ErrorContains(t, err, fmt.Sprintf("UNIQUE constraint failed: %s.path", dao.Table())) - }) - - t.Run("constraints", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Build() - - // No course ID - attachment := &models.Attachment{} - require.ErrorContains(t, dao.Create(attachment, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - attachment.CourseID = "" - require.ErrorContains(t, dao.Create(attachment, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - attachment.CourseID = "1234" - - // No asset ID - require.ErrorContains(t, dao.Create(attachment, nil), fmt.Sprintf("NOT NULL constraint failed: %s.asset_id", dao.Table())) - attachment.AssetID = "" - require.ErrorContains(t, dao.Create(attachment, nil), fmt.Sprintf("NOT NULL constraint failed: %s.asset_id", dao.Table())) - attachment.AssetID = "1234" - - // No title - require.ErrorContains(t, dao.Create(attachment, nil), fmt.Sprintf("NOT NULL constraint failed: %s.title", dao.Table())) - attachment.Title = "" - require.ErrorContains(t, dao.Create(attachment, nil), fmt.Sprintf("NOT NULL constraint failed: %s.title", dao.Table())) - attachment.Title = "Course 1" - - // No path - require.ErrorContains(t, dao.Create(attachment, nil), fmt.Sprintf("NOT NULL constraint failed: %s.path", dao.Table())) - attachment.Path = "" - require.ErrorContains(t, dao.Create(attachment, nil), fmt.Sprintf("NOT NULL constraint failed: %s.path", dao.Table())) - attachment.Path = "/course 1/01 attachment" - - // Invalid course ID - require.ErrorContains(t, dao.Create(attachment, nil), "FOREIGN KEY constraint failed") - attachment.CourseID = testData[0].ID - - // Invalid asset ID - require.ErrorContains(t, dao.Create(attachment, nil), "FOREIGN KEY constraint failed") - attachment.AssetID = testData[0].Assets[0].ID - - // Success - require.Nil(t, dao.Create(attachment, nil)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAttachment_Get(t *testing.T) { - t.Run("found", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(2).Assets(1).Attachments(1).Build() - - a, err := dao.Get(testData[1].Assets[0].Attachments[0].ID, nil) - require.Nil(t, err) - require.Equal(t, testData[1].Assets[0].Attachments[0].ID, a.ID) - }) - - t.Run("not found", func(t *testing.T) { - dao, _ := attachmentSetup(t) - - c, err := dao.Get("1234", nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, c) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := attachmentSetup(t) - - c, err := dao.Get("", nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, c) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := attachmentSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Get("1234", nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAttachment_List(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := attachmentSetup(t) - - assets, err := dao.List(nil, nil) - require.Nil(t, err) - require.Zero(t, assets) - }) - - t.Run("found", func(t *testing.T) { - dao, db := attachmentSetup(t) - - NewTestBuilder(t).Db(db).Courses(5).Assets(1).Attachments(1).Build() - - result, err := dao.List(nil, nil) - require.Nil(t, err) - require.Len(t, result, 5) - - }) - - t.Run("orderby", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Assets(1).Attachments(1).Build() - - // ---------------------------- - // CREATED_AT DESC - // ---------------------------- - result, err := dao.List(&database.DatabaseParams{OrderBy: []string{"created_at desc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 3) - require.Equal(t, testData[2].Assets[0].Attachments[0].ID, result[0].ID) - require.Equal(t, testData[1].Assets[0].Attachments[0].ID, result[1].ID) - require.Equal(t, testData[0].Assets[0].Attachments[0].ID, result[2].ID) - - // ---------------------------- - // CREATED_AT ASC - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{"created_at asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 3) - require.Equal(t, testData[0].Assets[0].Attachments[0].ID, result[0].ID) - require.Equal(t, testData[1].Assets[0].Attachments[0].ID, result[1].ID) - require.Equal(t, testData[2].Assets[0].Attachments[0].ID, result[2].ID) - - // ---------------------------- - // Error - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{"unit_test asc"}}, nil) - require.ErrorContains(t, err, "no such column") - require.Nil(t, result) - }) - - t.Run("where", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Assets(2).Attachments(2).Build() - - // ---------------------------- - // EQUALS ID - // ---------------------------- - result, err := dao.List(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".id": testData[1].Assets[1].Attachments[0].ID}}, nil) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, testData[1].Assets[1].Attachments[0].ID, result[0].ID) - - // ---------------------------- - // EQUALS ID OR ID - // ---------------------------- - dbParams := &database.DatabaseParams{ - Where: squirrel.Or{ - squirrel.Eq{dao.Table() + ".id": testData[1].Assets[1].Attachments[0].ID}, - squirrel.Eq{dao.Table() + ".id": testData[2].Assets[0].Attachments[1].ID}, - }, - OrderBy: []string{"created_at asc"}, - } - - result, err = dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 2) - require.Equal(t, testData[1].Assets[1].Attachments[0].ID, result[0].ID) - require.Equal(t, testData[2].Assets[0].Attachments[1].ID, result[1].ID) - - // ---------------------------- - // ERROR - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Nil(t, result) - }) - - t.Run("pagination", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Attachments(17).Build() - - // ---------------------------- - // Page 1 with 10 items - // ---------------------------- - p := pagination.New(1, 10) - - result, err := dao.List(&database.DatabaseParams{Pagination: p}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, 17, p.TotalItems()) - require.Equal(t, testData[0].Assets[0].Attachments[0].ID, result[0].ID) - require.Equal(t, testData[0].Assets[0].Attachments[9].ID, result[9].ID) - - // ---------------------------- - // Page 2 with 7 items - // ---------------------------- - p = pagination.New(2, 10) - - result, err = dao.List(&database.DatabaseParams{Pagination: p}, nil) - require.Nil(t, err) - require.Len(t, result, 7) - require.Equal(t, 17, p.TotalItems()) - require.Equal(t, testData[0].Assets[0].Attachments[10].ID, result[0].ID) - require.Equal(t, testData[0].Assets[0].Attachments[16].ID, result[6].ID) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := attachmentSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.List(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAttachment_Delete(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(1).Attachments(1).Build() - err := dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].Assets[0].Attachments[0].ID}}, nil) - require.Nil(t, err) - }) - - t.Run("no db params", func(t *testing.T) { - dao, _ := attachmentSetup(t) - - err := dao.Delete(nil, nil) - require.ErrorIs(t, err, ErrMissingWhere) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := attachmentSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": "1234"}}, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestAttachment_DeleteCascade(t *testing.T) { - t.Run("delete course", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(2).Assets(1).Attachments(1).Build() - - // Delete the course - courseDao := NewCourseDao(db) - err := courseDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].ID}}, nil) - require.Nil(t, err) - - // Check the asset was deleted - s, err := dao.Get(testData[0].Assets[0].Attachments[0].ID, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, s) - }) - - t.Run("delete asset", func(t *testing.T) { - dao, db := attachmentSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(2).Assets(1).Attachments(1).Build() - - // Delete the asset - assetDao := NewAssetDao(db) - err := assetDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].Assets[0].ID}}, nil) - require.Nil(t, err) - - _, err = dao.Get(testData[0].Assets[0].Attachments[0].ID, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - }) -} diff --git a/daos/base.go b/daos/base.go deleted file mode 100644 index 8e2425d..0000000 --- a/daos/base.go +++ /dev/null @@ -1,132 +0,0 @@ -package daos - -import ( - "database/sql" - "errors" - "strings" - "time" - - "github.com/Masterminds/squirrel" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Scannable is an interface for a database row that can be scanned into a struct -type Scannable interface { - Scan(dest ...interface{}) error -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -type daoer interface { - Table() string - countSelect() squirrel.SelectBuilder - baseSelect() squirrel.SelectBuilder -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Errors -var ( - ErrEmptyId = errors.New("id cannot be empty") - ErrMissingCourseId = errors.New("course id cannot be empty") - ErrMissingWhere = errors.New("where clause cannot be empty") - ErrInvalidPrefix = errors.New("prefix must be greater than 0") - ErrNilTransaction = errors.New("transaction cannot be nil") - ErrMissingTag = errors.New("tag cannot be empty") -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// FormatTime formats a time.Time to a string -func FormatTime(t time.Time) string { - return t.Format("2006-01-02 15:04:05.000") -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ParseTime parses the time string from SQLite to time.Time -func ParseTime(t string) (time.Time, error) { - if t == "" { - return time.Time{}, nil - } - return time.Parse("2006-01-02 15:04:05.000", t) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ParseTimeNull parses a time string from a sql.NullString to time.Time -func ParseTimeNull(t sql.NullString) (time.Time, error) { - if t.Valid { - if value, err := ParseTime(t.String); err != nil { - return time.Time{}, err - } else { - return value, nil - } - } else { - return time.Time{}, nil - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NilStr returns nil when a string is empty -// -// Use this when inserting into the database to avoid inserting empty strings -func NilStr(s string) any { - if s == "" { - return nil - } - - return s -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// extractTableColumn extracts the table and column name from an orderBy string. If no table prefix -// is found, the table part is returned as an empty string -func extractTableColumn(orderBy string) (string, string) { - parts := strings.Fields(orderBy) - tableColumn := strings.Split(parts[0], ".") - - if len(tableColumn) == 2 { - return tableColumn[0], tableColumn[1] - } - - return "", tableColumn[0] -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// isValidOrderBy returns true if the orderBy string is valid. The table and column are validated -// against the given list of valid table.columns (ex. courses.id, scans.status as scan_status). -func isValidOrderBy(table, column string, validateTableColumns []string) bool { - // If the column is empty, always return false - if column == "" { - return false - } - - for _, validTc := range validateTableColumns { - // Wildcard match (ex. courses.* == id) - if table == "" && strings.HasSuffix(validTc, ".*") { - return true - } - - // Exact match (ex. id == id || courses.id == courses.id || courses.id as courses_id == courses.id) - if validTc == column || validTc == table+"."+column || strings.HasPrefix(validTc, table+"."+column+" as ") { - return true - } - - // Table + wildcard match (ex. courses.* == courses.id) - if strings.HasSuffix(validTc, ".*") && strings.HasPrefix(validTc, table+".") { - return true - } - - // courses.id as course_id == course_id - if strings.HasSuffix(validTc, " as "+column) { - return true - } - } - - return false -} diff --git a/daos/base_test.go b/daos/base_test.go deleted file mode 100644 index 8e7002f..0000000 --- a/daos/base_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package daos - -import "testing" - -func Test_IsValidOrderBy(t *testing.T) { - t.Run("with wildcard", func(t *testing.T) { - validTableColumns := []string{ - "courses.*", - "scans.status as scan_status", - "courses_progress.started", - "courses_progress.percent", - } - - tests := []struct { - name string - table string - column string - expected bool - }{ - // Wildcard match - {"valid .*", "", "nonexistent", true}, - {"valid .* and direct", "", "percent", true}, - {"valid .* and alias", "", "scan_status", true}, - - // Table.* match - {"valid table.*", "courses", "title", true}, - - // Table.column match - {"valid table.column", "courses_progress", "started", true}, - {"valid table.column as alias", "scans", "status", true}, - - // Invalid - {"invalid table.column", "test", "invalid", false}, - {"invalid column", "", "", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := isValidOrderBy(tt.table, tt.column, validTableColumns) - if result != tt.expected { - t.Errorf("isValidOrderBy(%s, %s) = %v; expected %v", tt.table, tt.column, result, tt.expected) - } - }) - } - }) - - t.Run("without wildcard", func(t *testing.T) { - validTableColumns := []string{ - "courses.id", - "data", - "scans.status as scan_status", - "courses_progress.started", - "courses_progress.percent", - } - - tests := []struct { - name string - table string - column string - expected bool - }{ - // Exact - {"valid direct", "", "data", true}, - {"valid direct as alias", "", "scan_status", true}, - - // Table.column - {"valid table.column", "courses_progress", "started", true}, - {"valid table.column as alias", "scans", "status", true}, - - // Wildcard - {"invalid .*", "", "nonexistent", false}, - {"invalid .* and direct", "", "percent", false}, - {"invalid .* and alias", "", "status", false}, - - // Table.* - {"invalid table.*", "courses", "title", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := isValidOrderBy(tt.table, tt.column, validTableColumns) - if result != tt.expected { - t.Errorf("isValidOrderBy(%s, %s) = %v; expected %v", tt.table, tt.column, result, tt.expected) - } - }) - } - }) -} diff --git a/daos/common_test.go b/daos/common_test.go deleted file mode 100644 index b8db430..0000000 --- a/daos/common_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package daos - -import ( - "sync" - "testing" - - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/utils/appFs" - "github.com/geerew/off-course/utils/logger" - "github.com/spf13/afero" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func setup(t *testing.T) *database.DatabaseManager { - t.Helper() - - // Logger - var logs []*logger.Log - var logsMux sync.Mutex - logger, _, err := logger.InitLogger(&logger.BatchOptions{ - BatchSize: 1, - WriteFn: logger.TestWriteFn(&logs, &logsMux), - }) - require.NoError(t, err, "Failed to initialize logger") - - // Filesystem - appFs := appFs.NewAppFs(afero.NewMemMapFs(), logger) - - // DB - dbManager, err := database.NewSqliteDBManager(&database.DatabaseConfig{ - IsDebug: false, - DataDir: "./oc_data", - AppFs: appFs, - InMemory: true, - }) - - require.Nil(t, err) - require.NotNil(t, dbManager) - - return dbManager -} diff --git a/daos/course.go b/daos/course.go deleted file mode 100644 index 995abfc..0000000 --- a/daos/course.go +++ /dev/null @@ -1,459 +0,0 @@ -package daos - -import ( - "database/sql" - "slices" - "strings" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/types" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// CourseDao is the data access object for courses -type CourseDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewCourseDao returns a new CourseDao -func NewCourseDao(db database.Database) *CourseDao { - return &CourseDao{ - db: db, - table: "courses", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *CourseDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Count returns the number of courses -func (dao *CourseDao) Count(params *database.DatabaseParams) (int, error) { - generic := NewGenericDao(dao.db, dao) - return generic.Count(params, nil) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Create inserts a new course and courses_progress row within a transaction -// -// NOTE: There is currently no support for users, but when there is, the default courses_progress -// should be inserted for the admin user -func (dao *CourseDao) Create(c *models.Course) error { - if c.ID == "" { - c.RefreshId() - } - - c.RefreshCreatedAt() - c.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(c)). - ToSql() - - return dao.db.RunInTransaction(func(tx *database.Tx) error { - // Create the course - if _, err := tx.Exec(query, args...); err != nil { - return err - } - - // Create the course progress - cp := &models.CourseProgress{ - CourseID: c.ID, - } - - cpDao := NewCourseProgressDao(dao.db) - return cpDao.Create(cp, tx) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Get selects a course with the given ID -// -// `tx` allows for the function to be run within a transaction -func (dao *CourseDao) Get(id string, dbParams *database.DatabaseParams, tx *database.Tx) (*models.Course, error) { - generic := NewGenericDao(dao.db, dao) - - courseDbParams := &database.DatabaseParams{ - Columns: dao.columns(), - Where: squirrel.Eq{dao.Table() + ".id": id}, - } - - row, err := generic.Get(courseDbParams, tx) - if err != nil { - return nil, err - } - - course, err := dao.scanRow(row) - if err != nil { - return nil, err - } - - return course, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// List selects courses -// -// `tx` allows for the function to be run within a transaction -func (dao *CourseDao) List(dbParams *database.DatabaseParams, tx *database.Tx) ([]*models.Course, error) { - generic := NewGenericDao(dao.db, dao) - - if dbParams == nil { - dbParams = &database.DatabaseParams{} - } - - // Process the order by clauses - dbParams.OrderBy = dao.ProcessOrderBy(dbParams.OrderBy) - - // Default the columns if not specified - if len(dbParams.Columns) == 0 { - dbParams.Columns = dao.columns() - } - - rows, err := generic.List(dbParams, tx) - if err != nil { - return nil, err - } - defer rows.Close() - - var courses []*models.Course - - for rows.Next() { - c, err := dao.scanRow(rows) - if err != nil { - return nil, err - } - - courses = append(courses, c) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return courses, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Update updates a course -// -// Note: Only `card_path` and `available` can be updated -func (dao *CourseDao) Update(course *models.Course, tx *database.Tx) error { - if course.ID == "" { - return ErrEmptyId - } - - course.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Update(dao.Table()). - Set("card_path", NilStr(course.CardPath)). - Set("available", course.Available). - Set("updated_at", FormatTime(course.UpdatedAt)). - Where("id = ?", course.ID). - ToSql() - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - _, err := execFn(query, args...) - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Delete deletes a course based upon the where clause -// -// `tx` allows for the function to be run within a transaction -func (dao *CourseDao) Delete(dbParams *database.DatabaseParams, tx *database.Tx) error { - if dbParams == nil || dbParams.Where == nil { - return ErrMissingWhere - } - - generic := NewGenericDao(dao.db, dao) - return generic.Delete(dbParams, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ClassifyPaths classifies the given paths into one of the following categories: -// - PathClassificationNone: The path does not exist in the courses table -// - PathClassificationAncestor: The path is an ancestor of a course path -// - PathClassificationCourse: The path is an exact match to a course path -// - PathClassificationDescendant: The path is a descendant of a course path -// -// The paths are returned as a map with the original path as the key and the classification as the -// value -func (dao *CourseDao) ClassifyPaths(paths []string) (map[string]types.PathClassification, error) { - paths = slices.DeleteFunc(paths, func(s string) bool { - return s == "" - }) - - if len(paths) == 0 { - return nil, nil - } - - // Initialize the results map - results := make(map[string]types.PathClassification) - for _, path := range paths { - results[path] = types.PathClassificationNone - } - - // Build the where clause - whereClause := make([]squirrel.Sqlizer, len(paths)) - for i, path := range paths { - whereClause[i] = squirrel.Like{dao.Table() + ".path": path + "%"} - } - - query, args, _ := squirrel. - StatementBuilder. - Select(dao.Table() + ".path"). - From(dao.table). - Where(squirrel.Or(whereClause)). - ToSql() - - rows, err := dao.db.Query(query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - // Store the found course paths - var coursePath string - coursePaths := []string{} - for rows.Next() { - if err := rows.Scan(&coursePath); err != nil { - return nil, err - } - coursePaths = append(coursePaths, coursePath) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - // Process - for _, path := range paths { - for _, coursePath := range coursePaths { - if coursePath == path { - results[path] = types.PathClassificationCourse - break - } else if strings.HasPrefix(coursePath, path) { - results[path] = types.PathClassificationAncestor - break - } else if strings.HasPrefix(path, coursePath) && results[path] != types.PathClassificationAncestor { - results[path] = types.PathClassificationDescendant - break - } - } - } - - return results, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ProcessOrderBy takes an array of strings representing orderBy clauses and returns a processed -// version of this array -// -// It will creates a new list of valid table columns based upon columns() for the current -// DAO. Additionally, it handles the special case of 'scan_status' column, which requires custom -// sorting logic, via a CASE statement. -// -// The custom sorting logic is defined as follows: -// - NULL values are treated as the lowest value (sorted first in ASC, last in DESC) -// - 'waiting' status is treated as the second value -// - 'processing' status is treated as the third value -func (dao *CourseDao) ProcessOrderBy(orderBy []string) []string { - if len(orderBy) == 0 { - return orderBy - } - - validTableColumns := dao.columns() - var processedOrderBy []string - - scanDao := NewScanDao(dao.db) - - for _, ob := range orderBy { - table, column := extractTableColumn(ob) - - if isValidOrderBy(table, column, validTableColumns) { - // When the column is 'scan_status', apply the custom sorting logic - if column == "scan_status" || table+"."+column == scanDao.Table()+".status" { - // Determine the sort direction, defaulting to ASC if not specified - parts := strings.Fields(ob) - sortDirection := "ASC" - if len(parts) > 1 { - sortDirection = strings.ToUpper(parts[1]) - } - - caseStmt := "CASE " + - "WHEN scan_status IS NULL THEN 1 " + - "WHEN scan_status = 'waiting' THEN 2 " + - "WHEN scan_status = 'processing' THEN 3 " + - "END " + sortDirection - - processedOrderBy = append(processedOrderBy, caseStmt) - } else { - processedOrderBy = append(processedOrderBy, ob) - } - } - } - - return processedOrderBy -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default select builder for counting -func (dao *CourseDao) countSelect() squirrel.SelectBuilder { - sDao := NewScanDao(dao.db) - cpDao := NewCourseProgressDao(dao.db) - - return squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - LeftJoin(sDao.Table() + " ON " + dao.Table() + ".id = " + sDao.Table() + ".course_id"). - LeftJoin(cpDao.Table() + " ON " + dao.Table() + ".id = " + cpDao.Table() + ".course_id"). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// baseSelect returns the default select builder -// -// It performs 2 left joins -// - scans table to get `scan_status` -// - courses progress table to get `started`, `started_at`, `percent`, and `completed_at` -// -// Note: The columns are removed, so you must specify the columns with `.Columns(...)` when using -// this select builder -func (dao *CourseDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *CourseDao) columns() []string { - sDao := NewScanDao(dao.db) - cpDao := NewCourseProgressDao(dao.db) - - return []string{ - dao.Table() + ".*", - sDao.Table() + ".status as scan_status", - cpDao.Table() + ".started", - cpDao.Table() + ".started_at", - cpDao.Table() + ".percent", - cpDao.Table() + ".completed_at", - cpDao.Table() + ".updated_at as progress_updated_at", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for a course -func (dao *CourseDao) data(c *models.Course) map[string]any { - return map[string]any{ - "id": c.ID, - "title": NilStr(c.Title), - "path": NilStr(c.Path), - "card_path": NilStr(c.CardPath), - "available": c.Available, - "created_at": FormatTime(c.CreatedAt), - "updated_at": FormatTime(c.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans a course row -func (dao *CourseDao) scanRow(scannable Scannable) (*models.Course, error) { - var c models.Course - - // Nullable fields - var cardPath sql.NullString - var scanStatus sql.NullString - var createdAt string - var updatedAt string - var startedAt sql.NullString - var completedAt sql.NullString - var progressUpdatedAt sql.NullString - - err := scannable.Scan( - // Course - &c.ID, - &c.Title, - &c.Path, - &cardPath, - &c.Available, - &createdAt, - &updatedAt, - // Scan - &scanStatus, - // Course progress - &c.Started, - &startedAt, - &c.Percent, - &completedAt, - &progressUpdatedAt, - ) - - if err != nil { - return nil, err - } - - if cardPath.Valid { - c.CardPath = cardPath.String - } - - if scanStatus.Valid { - c.ScanStatus = scanStatus.String - } - - if c.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if c.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - if c.StartedAt, err = ParseTimeNull(startedAt); err != nil { - return nil, err - } - - if c.CompletedAt, err = ParseTimeNull(completedAt); err != nil { - return nil, err - } - - if c.ProgressUpdatedAt, err = ParseTime(progressUpdatedAt.String); err != nil { - return nil, err - } - - return &c, nil -} diff --git a/daos/course_progress.go b/daos/course_progress.go deleted file mode 100644 index f37937f..0000000 --- a/daos/course_progress.go +++ /dev/null @@ -1,281 +0,0 @@ -package daos - -import ( - "database/sql" - "math" - "time" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// CourseProgressDao is the data access object for courses progress -type CourseProgressDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewCourseProgressDao returns a new CourseProgressDao -func NewCourseProgressDao(db database.Database) *CourseProgressDao { - return &CourseProgressDao{ - db: db, - table: "courses_progress", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *CourseProgressDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Create inserts a new course progress -// -// `tx` allows for the function to be run within a transaction -func (dao *CourseProgressDao) Create(cp *models.CourseProgress, tx *database.Tx) error { - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - if cp.ID == "" { - cp.RefreshId() - } - - cp.RefreshCreatedAt() - cp.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(cp)). - ToSql() - - _, err := execFn(query, args...) - - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Get selects a course progress with the given course ID -// -// `tx` allows for the function to be run within a transaction -func (dao *CourseProgressDao) Get(courseId string, tx *database.Tx) (*models.CourseProgress, error) { - generic := NewGenericDao(dao.db, dao) - - dbParams := &database.DatabaseParams{ - Columns: dao.columns(), - Where: squirrel.Eq{dao.Table() + ".course_id": courseId}, - } - - row, err := generic.Get(dbParams, tx) - if err != nil { - return nil, err - } - - cp, err := dao.scanRow(row) - if err != nil { - return nil, err - } - - return cp, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Refresh does a refresh of the current course progress for the given ID -// -// It calculates the number of assets, number of completed assets and number of video assets started. It -// then calculates the percent complete and whether the course has been started or not. -// -// Based upon this calculation, -// - If the course has been started but `started_at` is null, `started_at` will be set to the current time -// - If the course is not started, `started_at` is set to null -// - If the course is 100% complete but `completed_at` is null, `completed_at` is set to the current time -// - If the course is not complete, `completed_at` is set to null -// -// `tx` allows for the function to be run within a transaction -func (dao *CourseProgressDao) Refresh(courseId string, tx *database.Tx) error { - if courseId == "" { - return ErrEmptyId - } - - queryRowFn := dao.db.QueryRow - execFn := dao.db.Exec - - if tx != nil { - queryRowFn = tx.QueryRow - execFn = tx.Exec - } - - aDao := NewAssetDao(dao.db) - apDao := NewAssetProgressDao(dao.db) - - // Count the number of assets, number of completed assets and number of video assets started for - // this course - query, args, _ := squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select( - "COUNT(DISTINCT "+aDao.Table()+".id) AS total_count", - "SUM(CASE WHEN "+apDao.Table()+".completed THEN 1 ELSE 0 END) AS completed_count", - "SUM(CASE WHEN "+apDao.Table()+".video_pos > 0 THEN 1 ELSE 0 END) AS started_count"). - From(aDao.Table()). - LeftJoin(apDao.Table() + " ON " + aDao.Table() + ".id = " + apDao.Table() + ".asset_id"). - Where(squirrel.And{squirrel.Eq{aDao.Table() + ".course_id": courseId}}). - ToSql() - - var totalAssetCount sql.NullInt32 - var completedAssetCount sql.NullInt32 - var startedAssetCount sql.NullInt32 - err := queryRowFn(query, args...).Scan(&totalAssetCount, &completedAssetCount, &startedAssetCount) - if err != nil { - return err - } - - // Default values - isStarted := false - startedAt := time.Time{} - percent := int(math.Abs((float64(completedAssetCount.Int32) * float64(100)) / float64(totalAssetCount.Int32))) - completedAt := time.Time{} - updatedAt := time.Now() - - // When there are started assets or percent is between >0 and <=100, set started to true and set started_at - if startedAssetCount.Int32 > 0 || percent > 0 && percent <= 100 { - isStarted = true - startedAt = updatedAt - } - - // When percent is 100, set completed_at - if percent == 100 { - completedAt = startedAt - } - - builder := squirrel. - StatementBuilder. - Update(dao.Table()). - Set("started", isStarted). - Set("percent", percent). - Set("updated_at", FormatTime(updatedAt)). - Where("course_id = ?", courseId) - - if isStarted { - builder = builder.Set("started_at", squirrel.Expr("COALESCE(started_at, ?)", FormatTime(startedAt))) - } else { - builder = builder.Set("started_at", nil) - } - - if percent == 100 { - builder = builder.Set("completed_at", squirrel.Expr("COALESCE(completed_at, ?)", FormatTime(completedAt))) - } else { - builder = builder.Set("completed_at", nil) - } - - query, args, _ = builder.ToSql() - - _, err = execFn(query, args...) - return err - -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default count select builder -func (dao *CourseProgressDao) countSelect() squirrel.SelectBuilder { - return squirrel.StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// baseSelect returns the default select builder -// -// Note: The columns are removed, so you must specify the columns with `.Columns(...)` when using -// this select builder -func (dao *CourseProgressDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *CourseProgressDao) columns() []string { - return []string{ - dao.Table() + ".*", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for a course progress -func (dao *CourseProgressDao) data(cp *models.CourseProgress) map[string]any { - return map[string]any{ - "id": cp.ID, - "course_id": NilStr(cp.CourseID), - "started": cp.Started, - "started_at": FormatTime(cp.StartedAt), - "percent": cp.Percent, - "completed_at": FormatTime(cp.CompletedAt), - "created_at": FormatTime(cp.CreatedAt), - "updated_at": FormatTime(cp.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans a courses progress row -func (dao *CourseProgressDao) scanRow(scannable Scannable) (*models.CourseProgress, error) { - var cp models.CourseProgress - - var createdAt string - var updatedAt string - var startedAt sql.NullString - var completedAt sql.NullString - - err := scannable.Scan( - &cp.ID, - &cp.CourseID, - &cp.Started, - &startedAt, - &cp.Percent, - &completedAt, - &createdAt, - &updatedAt, - ) - - if err != nil { - return nil, err - } - - if cp.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if cp.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - if cp.StartedAt, err = ParseTimeNull(startedAt); err != nil { - return nil, err - } - - if cp.CompletedAt, err = ParseTimeNull(completedAt); err != nil { - return nil, err - } - - return &cp, nil -} diff --git a/daos/course_progress_test.go b/daos/course_progress_test.go deleted file mode 100644 index e3043d5..0000000 --- a/daos/course_progress_test.go +++ /dev/null @@ -1,258 +0,0 @@ -package daos - -import ( - "database/sql" - "fmt" - "testing" - "time" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func CourseProgressSetup(t *testing.T) (*CourseProgressDao, database.Database) { - t.Helper() - - dbManager := setup(t) - cpDao := NewCourseProgressDao(dbManager.DataDb) - return cpDao, dbManager.DataDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseProgress_Create(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := CourseProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - cp, err := dao.Get(testData[0].ID, nil) - require.Nil(t, err) - require.False(t, cp.Started) - require.True(t, cp.StartedAt.IsZero()) - require.Zero(t, cp.Percent) - require.True(t, cp.CompletedAt.IsZero()) - require.False(t, cp.CreatedAt.IsZero()) - require.False(t, cp.UpdatedAt.IsZero()) - }) - - t.Run("duplicate course id", func(t *testing.T) { - dao, db := CourseProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - cp, err := dao.Get(testData[0].ID, nil) - require.Nil(t, err) - - err = dao.Create(cp, nil) - require.ErrorContains(t, err, fmt.Sprintf("UNIQUE constraint failed: %s.course_id", dao.Table())) - }) - - t.Run("constraint errors", func(t *testing.T) { - dao, db := CourseProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - // Delete the courses_progress row using squirrel - query, args, _ := squirrel.StatementBuilder.Delete(dao.Table()).Where(squirrel.Eq{"course_id": testData[0].ID}).ToSql() - _, err := db.Exec(query, args...) - require.Nil(t, err) - - // Course ID - cp := &models.CourseProgress{} - require.ErrorContains(t, dao.Create(cp, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - cp.CourseID = "" - require.ErrorContains(t, dao.Create(cp, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - cp.CourseID = "1234" - - // Invalid Course ID - require.ErrorContains(t, dao.Create(cp, nil), "FOREIGN KEY constraint failed") - cp.CourseID = testData[0].ID - - // Success - require.Nil(t, dao.Create(cp, nil)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseProgress_Get(t *testing.T) { - t.Run("found", func(t *testing.T) { - dao, db := CourseProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - cp, err := dao.Get(testData[0].ID, nil) - require.Nil(t, err) - require.Equal(t, testData[0].ID, cp.CourseID) - }) - - t.Run("not found", func(t *testing.T) { - dao, _ := CourseProgressSetup(t) - - cp, err := dao.Get("1234", nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, cp) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := CourseProgressSetup(t) - - cp, err := dao.Get("", nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, cp) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := CourseProgressSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Get("1234", nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseProgress_Update(t *testing.T) { - t.Run("status", func(t *testing.T) { - dao, db := CourseProgressSetup(t) - - // Create a course with 2 assets - apDao := NewAssetProgressDao(db) - testData := NewTestBuilder(t).Db(db).Courses(1).Assets(2).Build() - - // Create an asset progress for the first asset - assetProgressDao := NewAssetProgressDao(db) - ap1 := &models.AssetProgress{ - AssetID: testData[0].Assets[0].ID, - CourseID: testData[0].ID, - } - - require.Nil(t, assetProgressDao.Create(ap1, nil)) - - // Ensure the course percent is 0, started is false, and the started_at and completed_at are not set - origCp, err := dao.Get(testData[0].ID, nil) - require.Nil(t, err) - require.False(t, origCp.Started) - require.True(t, origCp.StartedAt.IsZero()) - require.Zero(t, origCp.Percent) - require.True(t, origCp.CompletedAt.IsZero()) - - // ---------------------------- - // Set the first asset to completed - // ---------------------------- - time.Sleep(1 * time.Millisecond) - ap1.Completed = true - require.Nil(t, apDao.Update(ap1, nil)) - - // Check the course percent is 50, started is true, started_at is set and completed_at is not set - updatedCp1, err := dao.Get(origCp.CourseID, nil) - require.Nil(t, err) - require.True(t, updatedCp1.Started) - require.False(t, updatedCp1.StartedAt.IsZero()) - require.Equal(t, 50, updatedCp1.Percent) - require.True(t, updatedCp1.CompletedAt.IsZero()) - - // ---------------------------- - // Set the second asset to completed - // ---------------------------- - ap2 := &models.AssetProgress{ - AssetID: testData[0].Assets[1].ID, - CourseID: testData[0].ID, - Completed: true, - } - - require.Nil(t, apDao.Create(ap2, nil)) - - // Check the course percent is 100, started is true, and started_at and completed_at are set - updatedCp2, err := dao.Get(origCp.CourseID, nil) - require.Nil(t, err) - require.True(t, updatedCp2.Started) - require.False(t, updatedCp2.StartedAt.IsZero()) - require.Equal(t, updatedCp2.StartedAt.String(), updatedCp1.StartedAt.String()) - require.Equal(t, 100, updatedCp2.Percent) - require.False(t, updatedCp2.CompletedAt.IsZero()) - - // ---------------------------- - // Set the second asset as uncompleted - // ---------------------------- - ap2.Completed = false - require.Nil(t, apDao.Update(ap2, nil)) - - // Check the course percent is 50, started is true, started_at is set and completed_at is not set - updatedCp3, err := dao.Get(origCp.CourseID, nil) - require.Nil(t, err) - require.True(t, updatedCp3.Started) - require.False(t, updatedCp3.StartedAt.IsZero()) - require.Equal(t, updatedCp3.StartedAt.String(), updatedCp2.StartedAt.String()) - require.Equal(t, 50, updatedCp3.Percent) - require.True(t, updatedCp3.CompletedAt.IsZero()) - - // ---------------------------- - // Set the first asset as uncompleted - // ---------------------------- - time.Sleep(1 * time.Millisecond) - ap1.Completed = false - require.Nil(t, apDao.Update(ap1, nil)) - - // Check the percent is 0, started is false and started_at and completed_at are not set - updatedCp4, err := dao.Get(origCp.CourseID, nil) - require.Nil(t, err) - require.False(t, updatedCp4.Started) - require.True(t, updatedCp4.StartedAt.IsZero()) - require.Zero(t, updatedCp4.Percent) - require.True(t, updatedCp4.CompletedAt.IsZero()) - }) - - t.Run("empty id", func(t *testing.T) { - dao, db := CourseProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - origCp, err := dao.Get(testData[0].ID, nil) - require.Nil(t, err) - - origCp.CourseID = "" - - err = dao.Refresh(origCp.CourseID, nil) - require.EqualError(t, err, "id cannot be empty") - }) - - t.Run("db error", func(t *testing.T) { - dao, db := CourseProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - origCp, err := dao.Get(testData[0].ID, nil) - require.Nil(t, err) - - _, err = db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Refresh(origCp.CourseID, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseProgress_DeleteCascade(t *testing.T) { - dao, db := CourseProgressSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - // Delete the course - courseDao := NewCourseDao(db) - err := courseDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].ID}}, nil) - require.Nil(t, err) - - // Check the course progress was deleted - cp, err := dao.Get(testData[0].ID, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, cp) -} diff --git a/daos/course_tag.go b/daos/course_tag.go deleted file mode 100644 index 3f7f6b8..0000000 --- a/daos/course_tag.go +++ /dev/null @@ -1,327 +0,0 @@ -package daos - -import ( - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// CourseTagDao is the data access object for courses tags -type CourseTagDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewCourseTagDao returns a new CourseTagDao -func NewCourseTagDao(db database.Database) *CourseTagDao { - return &CourseTagDao{ - db: db, - table: "courses_tags", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *CourseTagDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Count returns the number of course-tags -func (dao *CourseTagDao) Count(dbParams *database.DatabaseParams, tx *database.Tx) (int, error) { - generic := NewGenericDao(dao.db, dao) - return generic.Count(dbParams, nil) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Create inserts a new course-tag and tag if it does not exist -// -// If `tx` is nil, the function will create a new transaction, else it will use the current -// transaction -func (dao *CourseTagDao) Create(ct *models.CourseTag, tx *database.Tx) error { - if tx == nil { - return dao.db.RunInTransaction(func(tx *database.Tx) error { - return dao.create(ct, tx) - }) - } else { - return dao.create(ct, tx) - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// List selects courses -// -// `tx` allows for the function to be run within a transaction -func (dao *CourseTagDao) List(dbParams *database.DatabaseParams, tx *database.Tx) ([]*models.CourseTag, error) { - generic := NewGenericDao(dao.db, dao) - - if dbParams == nil { - dbParams = &database.DatabaseParams{} - } - - // Process the order by clauses - dbParams.OrderBy = dao.ProcessOrderBy(dbParams.OrderBy, false) - - // Default the columns if not specified - if len(dbParams.Columns) == 0 { - dbParams.Columns = dao.columns() - } - - rows, err := generic.List(dbParams, tx) - if err != nil { - return nil, err - } - defer rows.Close() - - var cts []*models.CourseTag - - for rows.Next() { - ct, err := dao.scanRow(rows) - if err != nil { - return nil, err - } - - cts = append(cts, ct) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return cts, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ListCourseIdsByTags selects course ids which have all the tags. -func (dao *CourseTagDao) ListCourseIdsByTags(tags []string, dbParams *database.DatabaseParams) ([]string, error) { - if len(tags) == 0 { - return nil, nil - } - - generic := NewGenericDao(dao.db, dao) - - if dbParams == nil { - dbParams = &database.DatabaseParams{} - } - - dbParams.OrderBy = dao.ProcessOrderBy(dbParams.OrderBy, false) - dbParams.Columns = []string{dao.Table() + ".course_id"} - dbParams.Where = squirrel.Eq{NewTagDao(dao.db).Table() + ".tag": tags} - dbParams.GroupBys = []string{dao.Table() + ".course_id"} - dbParams.Having = squirrel.Expr("COUNT(DISTINCT "+NewTagDao(dao.db).Table()+".tag) = ?", len(tags)) - dbParams.Pagination = nil - - rows, err := generic.List(dbParams, nil) - if err != nil { - return nil, err - } - defer rows.Close() - - var courseIds []string - for rows.Next() { - var courseId string - if err := rows.Scan(&courseId); err != nil { - return nil, err - } - - courseIds = append(courseIds, courseId) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return courseIds, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Delete deletes a course-tag based upon the where clause -// -// `tx` allows for the function to be run within a transaction -func (dao *CourseTagDao) Delete(dbParams *database.DatabaseParams, tx *database.Tx) error { - if dbParams == nil || dbParams.Where == nil { - return ErrMissingWhere - } - - generic := NewGenericDao(dao.db, dao) - return generic.Delete(dbParams, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ProcessOrderBy takes an array of strings representing orderBy clauses and returns a processed -// version of this array -// -// It will creates a new list of valid table columns based upon columns() for the current -// DAO -func (dao *CourseTagDao) ProcessOrderBy(orderBy []string, explicit bool) []string { - if len(orderBy) == 0 { - return orderBy - } - - generic := NewGenericDao(dao.db, dao) - return generic.ProcessOrderBy(orderBy, dao.columns(), explicit) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// create inserts a new course-tag and tag if it does not exist -// -// This function is used by Create() and always runs within a transaction -func (dao *CourseTagDao) create(ct *models.CourseTag, tx *database.Tx) error { - if tx == nil { - return ErrNilTransaction - } - - if ct.Tag == "" { - return ErrMissingTag - } - - if ct.CourseId == "" { - return ErrMissingCourseId - } - - if ct.ID == "" { - ct.RefreshId() - } - - ct.RefreshCreatedAt() - ct.RefreshUpdatedAt() - - // Check if the tag exists. This should return 0 or 1 tags as tags are unique - tagDao := NewTagDao(dao.db) - - tags, err := tagDao.List(&database.DatabaseParams{Where: squirrel.Eq{"tag": ct.Tag}}, tx) - if err != nil { - return err - } - - // Create the tag if it doesn't exist - if len(tags) == 0 { - tag := &models.Tag{ - Tag: ct.Tag, - } - - if err := tagDao.Create(tag, tx); err != nil { - return err - } - - ct.TagId = tag.ID - } else { - ct.TagId = tags[0].ID - } - - // Insert the course-tag - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(ct)). - ToSql() - - _, err = tx.Exec(query, args...) - - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default count select builder -func (dao *CourseTagDao) countSelect() squirrel.SelectBuilder { - tagDao := NewTagDao(dao.db) - courseDao := NewCourseDao(dao.db) - - return squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - LeftJoin(courseDao.Table() + " ON " + dao.Table() + ".course_id = " + courseDao.Table() + ".id"). - LeftJoin(tagDao.Table() + " ON " + dao.Table() + ".tag_id = " + tagDao.Table() + ".id"). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// baseSelect returns the default select builder -// -// It performs 2 left joins -// - courses table to get `title` -// - tags table to get `tag` -// -// Note: The columns are removed, so you must specify the columns with `.Columns(...)` when using -// this select builder -func (dao *CourseTagDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *CourseTagDao) columns() []string { - tagDao := NewTagDao(dao.db) - courseDao := NewCourseDao(dao.db) - - return []string{ - dao.Table() + ".*", - courseDao.Table() + ".title as course", - tagDao.Table() + ".tag", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for a course-tag -func (dao *CourseTagDao) data(ct *models.CourseTag) map[string]any { - return map[string]any{ - "id": ct.ID, - "tag_id": NilStr(ct.TagId), - "course_id": NilStr(ct.CourseId), - "created_at": FormatTime(ct.CreatedAt), - "updated_at": FormatTime(ct.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans a course-tag row -func (dao *CourseTagDao) scanRow(scannable Scannable) (*models.CourseTag, error) { - var ct models.CourseTag - - var createdAt string - var updatedAt string - - err := scannable.Scan( - &ct.ID, - &ct.TagId, - &ct.CourseId, - &createdAt, - &updatedAt, - &ct.Course, - &ct.Tag, - ) - - if err != nil { - return nil, err - } - - if ct.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if ct.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - return &ct, nil -} diff --git a/daos/course_tag_test.go b/daos/course_tag_test.go deleted file mode 100644 index 25f99a4..0000000 --- a/daos/course_tag_test.go +++ /dev/null @@ -1,392 +0,0 @@ -package daos - -import ( - "fmt" - "testing" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/pagination" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func courseTagSetup(t *testing.T) (*CourseTagDao, database.Database) { - t.Helper() - - dbManager := setup(t) - courseTagDao := NewCourseTagDao(dbManager.DataDb) - return courseTagDao, dbManager.DataDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseTag_Count(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Zero(t, count) - }) - - t.Run("entries", func(t *testing.T) { - dao, db := courseTagSetup(t) - - NewTestBuilder(t).Db(db).Courses(2).Tags(6).Build() - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Equal(t, count, 12) - }) - - t.Run("where", func(t *testing.T) { - dao, db := courseTagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(2).Tags([]string{"C", "Go", "Java", "TypeScript", "JavaScript"}).Build() - - courseDao := NewCourseDao(dao.db) - tagDao := NewTagDao(dao.db) - - // ---------------------------- - // EQUALS - // ---------------------------- - count, err := dao.Count(&database.DatabaseParams{Where: squirrel.Eq{courseDao.Table() + ".title": testData[1].Course.Title}}, nil) - require.Nil(t, err) - require.Equal(t, 5, count) - - // ---------------------------- - // NOT EQUALS - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.NotEq{tagDao.Table() + ".tag": "Go"}}, nil) - require.Nil(t, err) - require.Equal(t, 8, count) - - // ---------------------------- - // STARTS WITH (Java%) - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Like{tagDao.Table() + ".tag": "Java%"}}, nil) - require.Nil(t, err) - require.Equal(t, 4, count) - - // ---------------------------- - // ERROR - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Zero(t, count) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseTagSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Count(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseTag_Create(t *testing.T) { - t.Run("success (new tag)", func(t *testing.T) { - dao, db := courseTagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - ct := &models.CourseTag{ - CourseId: testData[0].Course.ID, - Tag: test_tags[0], - } - - // Create the course-tag. This will also create the tag - err := dao.Create(ct, nil) - require.Nil(t, err) - }) - - t.Run("success (existing tag)", func(t *testing.T) { - dao, db := courseTagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - // Create the tag - tagDao := NewTagDao(db) - tag := &models.Tag{ - Tag: test_tags[0], - } - require.Nil(t, tagDao.Create(tag, nil)) - - // Create the course-tag - ct := &models.CourseTag{ - TagId: tag.ID, - CourseId: testData[0].Course.ID, - Tag: tag.Tag, - } - - err := dao.Create(ct, nil) - require.Nil(t, err) - }) - - t.Run("duplicate", func(t *testing.T) { - dao, db := courseTagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - ct := &models.CourseTag{ - CourseId: testData[0].Course.ID, - Tag: test_tags[0], - } - - // Create the course-tag. This will also create the tag - require.Nil(t, dao.Create(ct, nil)) - - // Create the course-tag (again) - require.ErrorContains(t, dao.Create(ct, nil), fmt.Sprintf("UNIQUE constraint failed: %s.tag_id, %s.course_id", dao.Table(), dao.Table())) - }) - - t.Run("constraints", func(t *testing.T) { - dao, db := courseTagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - // tag := "test" - - // Tag - ct := &models.CourseTag{} - require.ErrorIs(t, dao.Create(ct, nil), ErrMissingTag) - ct.Tag = "test" - - // Course ID - require.ErrorIs(t, dao.Create(ct, nil), ErrMissingCourseId) - ct.CourseId = "1234" - require.ErrorContains(t, dao.Create(ct, nil), "constraint failed: FOREIGN KEY constraint failed") - ct.CourseId = testData[0].Course.ID - - // Success - require.Nil(t, dao.Create(ct, nil)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseTag_ListCourseIdsByTags(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - tags, err := dao.ListCourseIdsByTags([]string{"1234"}, nil) - require.Nil(t, err) - require.Zero(t, tags) - }) - - t.Run("found", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - course1 := NewTestBuilder(t).Db(dao.db).Courses([]string{"course 1"}).Tags([]string{"Go", "Data Structures"}).Build()[0] - course2 := NewTestBuilder(t).Db(dao.db).Courses([]string{"course 2"}).Tags([]string{"Data Structures", "TypeScript", "PHP"}).Build()[0] - course3 := NewTestBuilder(t).Db(dao.db).Courses([]string{"course 3"}).Tags([]string{"Go", "Data Structures", "PHP"}).Build()[0] - - // Order by title (asc) - dbParams := &database.DatabaseParams{OrderBy: []string{NewCourseDao(dao.db).Table() + ".title asc"}} - - // Go - result, err := dao.ListCourseIdsByTags([]string{"Go"}, dbParams) - require.Nil(t, err) - require.Len(t, result, 2) - require.Equal(t, course1.ID, result[0]) - require.Equal(t, course3.ID, result[1]) - - // Go, Data Structures - result, err = dao.ListCourseIdsByTags([]string{"Go", "Data Structures"}, dbParams) - require.Nil(t, err) - require.Len(t, result, 2) - require.Equal(t, course1.ID, result[0]) - require.Equal(t, course3.ID, result[1]) - - // Go, Data Structures, PHP - result, err = dao.ListCourseIdsByTags([]string{"Go", "Data Structures", "PHP"}, dbParams) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, course3.ID, result[0]) - - // Go, Data Structures, PHP, TypeScript - result, err = dao.ListCourseIdsByTags([]string{"Go", "Data Structures", "PHP", "TypeScript"}, dbParams) - require.Nil(t, err) - require.Len(t, result, 0) - - // Data Structures - result, err = dao.ListCourseIdsByTags([]string{"Data Structures"}, dbParams) - require.Nil(t, err) - require.Len(t, result, 3) - require.Equal(t, course1.ID, result[0]) - require.Equal(t, course2.ID, result[1]) - require.Equal(t, course3.ID, result[2]) - - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseTagSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.ListCourseIdsByTags([]string{"1234"}, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseTag_List(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - tags, err := dao.List(nil, nil) - require.Nil(t, err) - require.Zero(t, tags) - }) - - t.Run("found", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - NewTestBuilder(t).Db(dao.db).Courses(2).Tags(5).Build() - - result, err := dao.List(nil, nil) - require.Nil(t, err) - require.Len(t, result, 10) - }) - - t.Run("orderby", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - NewTestBuilder(t).Db(dao.db).Courses(2).Tags([]string{"PHP", "Go", "Java", "TypeScript", "JavaScript"}).Build() - tagDao := NewTagDao(dao.db) - - // ---------------------------- - // TAG DESC - // ---------------------------- - result, err := dao.List(&database.DatabaseParams{OrderBy: []string{tagDao.Table() + ".tag desc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, "TypeScript", result[0].Tag) - - // ---------------------------- - // TAG ASC - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{tagDao.Table() + ".tag asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, "Go", result[0].Tag) - - // ---------------------------- - // Error - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{"unit_test asc"}}, nil) - require.ErrorContains(t, err, "no such column") - require.Nil(t, result) - }) - - t.Run("where", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - testData := NewTestBuilder(t).Db(dao.db).Courses(2).Tags([]string{"PHP", "Go", "Java", "TypeScript", "JavaScript"}).Build() - - courseDao := NewCourseDao(dao.db) - tagDao := NewTagDao(dao.db) - - // ---------------------------- - // EQUALS (course title) - // ---------------------------- - result, err := dao.List(&database.DatabaseParams{Where: squirrel.Eq{courseDao.Table() + ".title": testData[0].Course.Title}}, nil) - require.Nil(t, err) - require.Len(t, result, 5) - - // ---------------------------- - // Like (Java%) - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Like{tagDao.Table() + ".tag": "Java%"}}, nil) - require.Nil(t, err) - require.Len(t, result, 4) - - // ---------------------------- - // ERROR - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Nil(t, result) - }) - - t.Run("pagination", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - NewTestBuilder(t).Db(dao.db).Courses(1).Tags(20).Build() - - // ---------------------------- - // Page 1 with 10 items - // ---------------------------- - p := pagination.New(1, 10) - - result, err := dao.List(&database.DatabaseParams{Pagination: p, OrderBy: []string{"tag asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, 20, p.TotalItems()) - require.Equal(t, "C", result[0].Tag) - - // ---------------------------- - // Page 2 with 10 items - // ---------------------------- - p = pagination.New(2, 10) - - result, err = dao.List(&database.DatabaseParams{Pagination: p, OrderBy: []string{"tag asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, 20, p.TotalItems()) - require.Equal(t, "Perl", result[0].Tag) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseTagSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.List(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseTag_Delete(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - testData := NewTestBuilder(t).Db(dao.db).Courses(1).Tags([]string{"C", "Go", "JavaScript", "Perl"}).Build() - - err := dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].Tags[1].ID}}, nil) - require.Nil(t, err) - - tags, err := dao.List(nil, nil) - require.Nil(t, err) - require.Len(t, tags, 3) - }) - - t.Run("no db params", func(t *testing.T) { - dao, _ := courseTagSetup(t) - - err := dao.Delete(nil, nil) - require.ErrorIs(t, err, ErrMissingWhere) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseTagSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"tag": "1234"}}, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} diff --git a/daos/course_test.go b/daos/course_test.go deleted file mode 100644 index ad2c437..0000000 --- a/daos/course_test.go +++ /dev/null @@ -1,600 +0,0 @@ -package daos - -import ( - "database/sql" - "fmt" - "path/filepath" - "strings" - "testing" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/pagination" - "github.com/geerew/off-course/utils/types" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func courseSetup(t *testing.T) (*CourseDao, database.Database) { - t.Helper() - - dbManager := setup(t) - courseDao := NewCourseDao(dbManager.DataDb) - return courseDao, dbManager.DataDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourse_Count(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := courseSetup(t) - - count, err := dao.Count(nil) - require.Nil(t, err) - require.Zero(t, count) - }) - - t.Run("entries", func(t *testing.T) { - dao, db := courseSetup(t) - - NewTestBuilder(t).Db(db).Courses(5).Build() - - count, err := dao.Count(nil) - require.Nil(t, err) - require.Equal(t, count, 5) - }) - - t.Run("where", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Build() - - // ---------------------------- - // EQUALS ID - // ---------------------------- - count, err := dao.Count(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".id": testData[2].ID}}) - require.Nil(t, err) - require.Equal(t, 1, count) - - // ---------------------------- - // NOT EQUALS ID - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.NotEq{dao.Table() + ".id": testData[2].ID}}) - require.Nil(t, err) - require.Equal(t, 2, count) - - // ---------------------------- - // ERROR - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}) - require.ErrorContains(t, err, "syntax error") - require.Zero(t, count) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Count(nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourse_Create(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, _ := courseSetup(t) - - testData := NewTestBuilder(t).Courses(1).Build() - - err := dao.Create(testData[0].Course) - require.Nil(t, err) - - newC, err := dao.Get(testData[0].ID, nil, nil) - require.Nil(t, err) - require.NotEmpty(t, newC.ID) - require.Equal(t, testData[0].Title, newC.Title) - require.Equal(t, testData[0].Path, newC.Path) - require.Empty(t, newC.CardPath) - require.False(t, newC.Available) - require.False(t, newC.CreatedAt.IsZero()) - require.False(t, newC.UpdatedAt.IsZero()) - //Scan status - require.Empty(t, newC.ScanStatus) - // Progress - require.False(t, newC.Started) - require.True(t, newC.StartedAt.IsZero()) - require.Zero(t, newC.Percent) - require.True(t, newC.CompletedAt.IsZero()) - }) - - t.Run("duplicate paths", func(t *testing.T) { - dao, _ := courseSetup(t) - - testData := NewTestBuilder(t).Courses(1).Build() - - err := dao.Create(testData[0].Course) - require.Nil(t, err) - - err = dao.Create(testData[0].Course) - require.ErrorContains(t, err, fmt.Sprintf("UNIQUE constraint failed: %s.path", dao.Table())) - }) - - t.Run("constraints", func(t *testing.T) { - dao, _ := courseSetup(t) - - // No title - c := &models.Course{} - require.ErrorContains(t, dao.Create(c), fmt.Sprintf("NOT NULL constraint failed: %s.title", dao.Table())) - c.Title = "" - require.ErrorContains(t, dao.Create(c), fmt.Sprintf("NOT NULL constraint failed: %s.title", dao.Table())) - c.Title = "Course 1" - - // No path - require.ErrorContains(t, dao.Create(c), fmt.Sprintf("NOT NULL constraint failed: %s.path", dao.Table())) - c.Path = "" - require.ErrorContains(t, dao.Create(c), fmt.Sprintf("NOT NULL constraint failed: %s.path", dao.Table())) - c.Path = "/course 1" - - // Success - require.Nil(t, dao.Create(c)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourse_Get(t *testing.T) { - t.Run("found", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(2).Assets(1).Build() - - c, err := dao.Get(testData[1].ID, nil, nil) - require.Nil(t, err) - require.Equal(t, testData[1].ID, c.ID) - require.Empty(t, testData[1].ScanStatus) - - // ---------------------------- - // scan - // ---------------------------- - scanDao := NewScanDao(db) - require.Nil(t, scanDao.Create(&models.Scan{CourseID: testData[1].ID}, nil)) - - c, err = dao.Get(testData[1].ID, nil, nil) - require.Nil(t, err) - require.Equal(t, string(types.ScanStatusWaiting), c.ScanStatus) - - // ---------------------------- - // Availability - // ---------------------------- - require.False(t, c.Available) - - // Set to started - testData[1].Available = true - require.Nil(t, dao.Update(testData[1].Course, nil)) - - c, err = dao.Get(testData[1].ID, nil, nil) - require.Nil(t, err) - require.True(t, c.Available) - - // ---------------------------- - // Progress - // ---------------------------- - require.False(t, c.Started) - require.True(t, c.StartedAt.IsZero()) - require.Zero(t, c.Percent) - require.True(t, c.CompletedAt.IsZero()) - }) - - t.Run("not found", func(t *testing.T) { - dao, _ := courseSetup(t) - - c, err := dao.Get("1234", nil, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, c) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := courseSetup(t) - - c, err := dao.Get("", nil, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, c) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Get("1234", nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourse_List(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := courseSetup(t) - - courses, err := dao.List(nil, nil) - require.Nil(t, err) - require.Zero(t, courses) - }) - - t.Run("found", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(5).Assets(2).Build() - - result, err := dao.List(nil, nil) - require.Nil(t, err) - require.Len(t, result, 5) - - // ---------------------------- - // Scan - // ---------------------------- - require.Empty(t, result[1].ScanStatus) - - // Create a scan for course 2 - scanDao := NewScanDao(db) - require.Nil(t, scanDao.Create(&models.Scan{CourseID: testData[1].ID}, nil)) - - result, err = dao.List(nil, nil) - require.Nil(t, err) - require.Equal(t, testData[1].ID, result[1].ID) - require.Equal(t, string(types.ScanStatusWaiting), result[1].ScanStatus) - - // ---------------------------- - // Availability - // ---------------------------- - for _, c := range result { - require.False(t, c.Available) - } - - // Set course 1 as available - testData[0].Available = true - require.Nil(t, dao.Update(testData[0].Course, nil)) - - // Find available courses - result, err = dao.List(&database.DatabaseParams{Where: squirrel.And{squirrel.Eq{dao.Table() + ".available": true}}}, nil) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, testData[0].ID, result[0].ID) - - // ---------------------------- - // Progress - // ---------------------------- - apDao := NewAssetProgressDao(db) - cpDao := NewCourseProgressDao(db) - - for _, c := range result { - require.False(t, c.Started) - require.True(t, c.StartedAt.IsZero()) - require.Zero(t, c.Percent) - require.True(t, c.CompletedAt.IsZero()) - } - - // Create progress for asset 1 in course 1 and set the video position to 50 - ap1 := &models.AssetProgress{AssetID: testData[0].Assets[0].ID, CourseID: testData[0].ID, VideoPos: 50} - - require.Nil(t, apDao.Create(ap1, nil)) - - // Find in-progress courses - dbParams := &database.DatabaseParams{ - Where: squirrel.And{ - squirrel.Eq{cpDao.Table() + ".started": true}, - squirrel.NotEq{cpDao.Table() + ".percent": 100}, - }, - } - - result, err = dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, testData[0].ID, result[0].ID) - - // Set progress for asset 1 in course 1 as complete - ap1.Completed = true - require.Nil(t, apDao.Update(ap1, nil)) - - // Create progress for asset 2 in course 1 and set completed to true - ap2 := &models.AssetProgress{AssetID: testData[0].Assets[1].ID, CourseID: testData[0].ID, Completed: true} - require.Nil(t, apDao.Create(ap2, nil)) - - // Find completed courses - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{cpDao.Table() + ".percent": 100}}, nil) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, testData[0].ID, result[0].ID) - }) - - t.Run("orderby", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Build() - - // ---------------------------- - // CREATED_AT DESC - // ---------------------------- - dbParams := &database.DatabaseParams{OrderBy: []string{"created_at desc"}} - result, err := dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 3) - require.Equal(t, testData[2].ID, result[0].ID) - - // ---------------------------- - // CREATED_AT ASC - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{"created_at asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 3) - require.Equal(t, testData[0].ID, result[0].ID) - - // ---------------------------- - // SCAN_STATUS DESC - // ---------------------------- - - // Create a scan for course 2 and 3 - scanDao := NewScanDao(db) - - testData[1].Scan = &models.Scan{CourseID: testData[1].ID} - require.Nil(t, scanDao.Create(testData[1].Scan, nil)) - testData[2].Scan = &models.Scan{CourseID: testData[2].ID} - require.Nil(t, scanDao.Create(testData[2].Scan, nil)) - - // Set course 3 to processing - testData[2].Scan.Status = types.NewScanStatus(types.ScanStatusProcessing) - require.Nil(t, scanDao.Update(testData[2].Scan, nil)) - - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{scanDao.Table() + ".status desc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 3) - - require.Equal(t, testData[0].ID, result[2].ID) - require.Equal(t, testData[1].ID, result[1].ID) - require.Equal(t, testData[2].ID, result[0].ID) - - // ---------------------------- - // SCAN_STATUS ASC - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{scanDao.Table() + ".status asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 3) - - require.Equal(t, testData[0].ID, result[0].ID) - require.Equal(t, testData[1].ID, result[1].ID) - require.Equal(t, testData[2].ID, result[2].ID) - - // ---------------------------- - // Error - // ---------------------------- - dbParams = &database.DatabaseParams{OrderBy: []string{"unit_test asc"}} - result, err = dao.List(dbParams, nil) - require.ErrorContains(t, err, "no such column") - require.Nil(t, result) - }) - - t.Run("where", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Build() - - // ---------------------------- - // EQUALS ID - // ---------------------------- - result, err := dao.List(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".id": testData[2].ID}}, nil) - require.Nil(t, err) - require.Len(t, result, 1) - require.Equal(t, testData[2].ID, result[0].ID) - - // ---------------------------- - // EQUALS ID OR ID - // ---------------------------- - dbParams := &database.DatabaseParams{ - Where: squirrel.Or{squirrel.Eq{dao.Table() + ".id": testData[1].ID}, squirrel.Eq{dao.Table() + ".id": testData[2].ID}}, - OrderBy: []string{"created_at asc"}, - } - result, err = dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 2) - require.Equal(t, testData[1].ID, result[0].ID) - require.Equal(t, testData[2].ID, result[1].ID) - - // ---------------------------- - // ERROR - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Nil(t, result) - }) - - t.Run("pagination", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(17).Build() - - // ---------------------------- - // Page 1 with 10 items - // ---------------------------- - p := pagination.New(1, 10) - - result, err := dao.List(&database.DatabaseParams{Pagination: p}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, 17, p.TotalItems()) - require.Equal(t, testData[0].ID, result[0].ID) - require.Equal(t, testData[9].ID, result[9].ID) - - // ---------------------------- - // Page 2 with 7 items - // ---------------------------- - p = pagination.New(2, 10) - - result, err = dao.List(&database.DatabaseParams{Pagination: p}, nil) - require.Nil(t, err) - require.Len(t, result, 7) - require.Equal(t, 17, p.TotalItems()) - require.Equal(t, testData[10].ID, result[0].ID) - require.Equal(t, testData[16].ID, result[6].ID) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.List(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourse_Update(t *testing.T) { - t.Run("card path", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - require.Empty(t, testData[0].CardPath) - - // Update the card path - testData[0].CardPath = "/path/to/card.jpg" - require.Nil(t, dao.Update(testData[0].Course, nil)) - - c, err := dao.Get(testData[0].ID, nil, nil) - require.Nil(t, err) - require.Equal(t, testData[0].CardPath, c.CardPath) - }) - - t.Run("available", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - require.False(t, testData[0].Available) - - // Update the availability - testData[0].Available = true - require.Nil(t, dao.Update(testData[0].Course, nil)) - - c, err := dao.Get(testData[0].ID, nil, nil) - require.Nil(t, err) - require.True(t, c.Available) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := courseSetup(t) - - err := dao.Update(&models.Course{}, nil) - require.ErrorIs(t, err, ErrEmptyId) - }) - - t.Run("invalid id", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - testData[0].ID = "1234" - require.Nil(t, dao.Update(testData[0].Course, nil)) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - testData := NewTestBuilder(t).Courses(1).Build() - - err = dao.Update(testData[0].Course, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourse_Delete(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - err := dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].ID}}, nil) - require.Nil(t, err) - }) - - t.Run("no db params", func(t *testing.T) { - dao, _ := courseSetup(t) - - err := dao.Delete(nil, nil) - require.ErrorIs(t, err, ErrMissingWhere) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": "1234"}}, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourse_ClassifyPaths(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := courseSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Build() - - path1 := string(filepath.Separator) // ancestor - path2 := strings.TrimSuffix(testData[1].Path, string(filepath.Separator)+"Course 2") // ancestor - path3 := string(filepath.Separator) + "test" // none - path4 := testData[2].Path // course - path5 := filepath.Join(testData[2].Path + "test") // descendant - - result, err := dao.ClassifyPaths([]string{path1, path2, path3, path4, path5}) - require.Nil(t, err) - - require.Equal(t, types.PathClassificationAncestor, result[path1]) - require.Equal(t, types.PathClassificationAncestor, result[path2], fmt.Sprintf("path2: %s, result: %d", path2, result[path2])) - require.Equal(t, types.PathClassificationNone, result[path3]) - require.Equal(t, types.PathClassificationCourse, result[path4]) - require.Equal(t, types.PathClassificationDescendant, result[path5]) - }) - - t.Run("no paths", func(t *testing.T) { - dao, _ := courseSetup(t) - - result, err := dao.ClassifyPaths([]string{}) - require.Nil(t, err) - require.Empty(t, result) - }) - - t.Run("empty path", func(t *testing.T) { - dao, _ := courseSetup(t) - - result, err := dao.ClassifyPaths([]string{"", "", ""}) - require.Nil(t, err) - require.Empty(t, result) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := courseSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - result, err := dao.ClassifyPaths([]string{"/"}) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - require.Empty(t, result) - }) -} diff --git a/daos/generic.go b/daos/generic.go deleted file mode 100644 index 9a253fe..0000000 --- a/daos/generic.go +++ /dev/null @@ -1,199 +0,0 @@ -package daos - -import ( - "database/sql" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// GenericDao is the data access object for generic queries -type GenericDao struct { - db database.Database - caller daoer -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewGenericDao returns a new GenericDao -func NewGenericDao(db database.Database, caller daoer) *GenericDao { - return &GenericDao{ - db: db, - caller: caller, - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Count returns the number of rows in a table -// -// `tx` allows for the function to be run within a transaction -func (dao *GenericDao) Count(dbParams *database.DatabaseParams, tx *database.Tx) (int, error) { - queryRowFn := dao.db.QueryRow - if tx != nil { - queryRowFn = tx.QueryRow - } - - builder := dao.caller.countSelect(). - Columns("COUNT(DISTINCT " + dao.caller.Table() + ".id)") - - if dbParams != nil && dbParams.Where != nil { - builder = builder.Where(dbParams.Where) - } - - query, args, _ := builder.ToSql() - - var count int - err := queryRowFn(query, args...).Scan(&count) - - return count, err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Get returns a row from a table -// -// `tx` allows for the function to be run within a transaction -func (dao *GenericDao) Get(dbParams *database.DatabaseParams, tx *database.Tx) (*sql.Row, error) { - queryRowFn := dao.db.QueryRow - if tx != nil { - queryRowFn = tx.QueryRow - } - - if dbParams == nil || dbParams.Where == nil { - return nil, ErrMissingWhere - } - - builder := dao.caller.baseSelect() - - if dbParams.Columns == nil { - builder = builder.Columns(dao.caller.Table() + ".*") - } else { - builder = builder.Columns(dbParams.Columns...) - } - - if dbParams.OrderBy != nil { - builder = builder.OrderBy(dbParams.OrderBy...) - } - - query, args, _ := builder. - Where(dbParams.Where). - ToSql() - - row := queryRowFn(query, args...) - if row.Err() != nil { - return nil, row.Err() - } - - return row, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// List returns rows from a table -// -// `tx` allows for the function to be run within a transaction -func (dao *GenericDao) List(dbParams *database.DatabaseParams, tx *database.Tx) (*sql.Rows, error) { - queryFn := dao.db.Query - if tx != nil { - queryFn = tx.Query - } - - builder := dao.caller.baseSelect() - - if dbParams != nil { - if dbParams.Columns == nil { - builder = builder.Columns(dao.caller.Table() + ".*") - } else { - builder = builder.Columns(dbParams.Columns...) - } - - if dbParams.Where != "" { - builder = builder.Where(dbParams.Where) - } - - if dbParams.OrderBy != nil { - builder = builder.OrderBy(dbParams.OrderBy...) - } - - if dbParams.Pagination != nil { - if count, err := dao.Count(dbParams, tx); err != nil { - return nil, err - } else { - dbParams.Pagination.SetCount(count) - builder = builder. - Offset(uint64(dbParams.Pagination.Offset())). - Limit(uint64(dbParams.Pagination.Limit())) - } - } - - if dbParams.GroupBys != nil { - builder = builder.GroupBy(dbParams.GroupBys...) - } - - if dbParams.Having != nil { - builder = builder.Having(dbParams.Having) - } - } - - query, args, _ := builder.ToSql() - - return queryFn(query, args...) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Delete deletes a row from a table based upon a where clause -// -// `tx` allows for the function to be run within a transaction -func (dao *GenericDao) Delete(dbParams *database.DatabaseParams, tx *database.Tx) error { - if dbParams == nil || dbParams.Where == nil { - return ErrMissingWhere - } - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - query, args, _ := squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Delete(dao.caller.Table()). - Where(dbParams.Where). - ToSql() - - _, err := execFn(query, args...) - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ProcessOrderBy takes an array of strings representing orderBy clauses and returns a processed -// version of this array -// -// It will creates a new list of valid Table columns based upon columns() for the current -// DAO -func (dao *GenericDao) ProcessOrderBy(orderBy []string, validColumns []string, explicit bool) []string { - if len(orderBy) == 0 { - return orderBy - } - - var processedOrderBy []string - - for _, ob := range orderBy { - Table, column := extractTableColumn(ob) - - if explicit && Table == "" { - continue - } - - if isValidOrderBy(Table, column, validColumns) { - processedOrderBy = append(processedOrderBy, ob) - } - } - - return processedOrderBy -} diff --git a/daos/log.go b/daos/log.go deleted file mode 100644 index 75f060e..0000000 --- a/daos/log.go +++ /dev/null @@ -1,202 +0,0 @@ -package daos - -import ( - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// LogDao is the data access object for logs -type LogDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewLogDao returns a new LogDao -func NewLogDao(db database.Database) *LogDao { - return &LogDao{ - db: db, - table: "logs", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *LogDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Count returns the number of logs -func (dao *LogDao) Count(params *database.DatabaseParams, tx *database.Tx) (int, error) { - generic := NewGenericDao(dao.db, dao) - return generic.Count(params, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Write inserts a new log -func (dao *LogDao) Write(l *models.Log, tx *database.Tx) error { - if l.ID == "" { - l.RefreshId() - } - - l.RefreshCreatedAt() - l.UpdatedAt = l.CreatedAt - - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(l)). - ToSql() - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - _, err := execFn(query, args...) - - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// List selects logs -// -// `tx` allows for the function to be run within a transaction -func (dao *LogDao) List(dbParams *database.DatabaseParams, tx *database.Tx) ([]*models.Log, error) { - generic := NewGenericDao(dao.db, dao) - - if dbParams == nil { - dbParams = &database.DatabaseParams{} - } - - // Always override the order by to created_at - dbParams.OrderBy = []string{dao.Table() + ".created_at DESC"} - - // Default the columns if not specified - if len(dbParams.Columns) == 0 { - dbParams.Columns = dao.columns() - } - - rows, err := generic.List(dbParams, tx) - if err != nil { - return nil, err - } - defer rows.Close() - - var logs []*models.Log - - for rows.Next() { - log, err := dao.scanRow(rows) - if err != nil { - return nil, err - } - - logs = append(logs, log) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return logs, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Delete deletes logs based upon the where clause -// -// `tx` allows for the function to be run within a transaction -func (dao *LogDao) Delete(dbParams *database.DatabaseParams, tx *database.Tx) error { - if dbParams == nil || dbParams.Where == nil { - return ErrMissingWhere - } - - generic := NewGenericDao(dao.db, dao) - return generic.Delete(dbParams, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default count select builder -func (dao *LogDao) countSelect() squirrel.SelectBuilder { - return squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// baseSelect returns the default select builder -func (dao *LogDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *LogDao) columns() []string { - return []string{ - dao.Table() + ".*", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for a log -func (dao *LogDao) data(a *models.Log) map[string]any { - return map[string]any{ - "id": a.ID, - "level": a.Level, - "message": NilStr(a.Message), - "data": a.Data, - "created_at": FormatTime(a.CreatedAt), - "updated_at": FormatTime(a.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans a log row -func (dao *LogDao) scanRow(scannable Scannable) (*models.Log, error) { - var l models.Log - - var createdAt string - var updatedAt string - - err := scannable.Scan( - &l.ID, - &l.Level, - &l.Message, - &l.Data, - &createdAt, - &updatedAt, - ) - - if err != nil { - return nil, err - } - - if l.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if l.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - return &l, nil -} diff --git a/daos/log_test.go b/daos/log_test.go deleted file mode 100644 index 8d3ebe8..0000000 --- a/daos/log_test.go +++ /dev/null @@ -1,248 +0,0 @@ -package daos - -import ( - "fmt" - "testing" - "time" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/pagination" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func logSetup(t *testing.T) (*LogDao, database.Database) { - t.Helper() - - dbManager := setup(t) - logDao := NewLogDao(dbManager.LogsDb) - return logDao, dbManager.LogsDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestLog_Count(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := logSetup(t) - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Zero(t, count) - }) - - t.Run("entries", func(t *testing.T) { - dao, _ := logSetup(t) - - for i := range 2 { - require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) - } - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Equal(t, count, 2) - }) - - t.Run("where", func(t *testing.T) { - dao, _ := logSetup(t) - - for i := range 2 { - require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) - } - - // ---------------------------- - // EQUALS ID - // ---------------------------- - count, err := dao.Count(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".message": "log 1"}}, nil) - require.Nil(t, err) - require.Equal(t, 1, count) - - // ---------------------------- - // NOT EQUALS ID - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.NotEq{dao.Table() + ".message": "log 1"}}, nil) - require.Nil(t, err) - require.Equal(t, 1, count) - - // ---------------------------- - // ERROR - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Zero(t, count) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := logSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Count(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestLog_Write(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, _ := logSetup(t) - - for i := range 2 { - require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) - } - - result, err := dao.List(nil, nil) - require.Nil(t, err) - require.Len(t, result, 2) - }) - - t.Run("constraints", func(t *testing.T) { - dao, _ := logSetup(t) - - // No message - log := &models.Log{} - require.ErrorContains(t, dao.Write(log, nil), fmt.Sprintf("NOT NULL constraint failed: %s.message", dao.Table())) - log.Message = "" - require.ErrorContains(t, dao.Write(log, nil), fmt.Sprintf("NOT NULL constraint failed: %s.message", dao.Table())) - log.Message = "Log 1" - - // Success - require.Nil(t, dao.Write(log, nil)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestLog_List(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := logSetup(t) - - courses, err := dao.List(nil, nil) - require.Nil(t, err) - require.Zero(t, courses) - }) - - t.Run("found", func(t *testing.T) { - dao, _ := logSetup(t) - - for i := range 5 { - require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) - time.Sleep(1 * time.Millisecond) - } - - result, err := dao.List(nil, nil) - require.Nil(t, err) - require.Len(t, result, 5) - require.Equal(t, "log 5", result[0].Message) - require.Equal(t, "log 1", result[4].Message) - }) - - t.Run("where", func(t *testing.T) { - dao, _ := logSetup(t) - - for i := range 5 { - require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) - time.Sleep(1 * time.Millisecond) - } - - // ---------------------------- - // EQUALS log 2 or log 3 - // ---------------------------- - result, err := dao.List( - &database.DatabaseParams{Where: squirrel.Or{ - squirrel.Eq{dao.Table() + ".message": "log 2"}, - squirrel.Eq{dao.Table() + ".message": "log 3"}}}, - nil) - require.Nil(t, err) - require.Len(t, result, 2) - require.Equal(t, "log 3", result[0].Message) - require.Equal(t, "log 2", result[1].Message) - - // ---------------------------- - // ERROR - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Nil(t, result) - }) - - t.Run("pagination", func(t *testing.T) { - dao, _ := logSetup(t) - - for i := range 17 { - require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) - time.Sleep(1 * time.Millisecond) - } - - // ---------------------------- - // Page 1 with 10 items - // ---------------------------- - p := pagination.New(1, 10) - - result, err := dao.List(&database.DatabaseParams{Pagination: p}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, 17, p.TotalItems()) - require.Equal(t, "log 17", result[0].Message) - require.Equal(t, "log 8", result[9].Message) - - // ---------------------------- - // Page 2 with 7 items - // ---------------------------- - p = pagination.New(2, 10) - - result, err = dao.List(&database.DatabaseParams{Pagination: p}, nil) - require.Nil(t, err) - require.Len(t, result, 7) - require.Equal(t, 17, p.TotalItems()) - require.Equal(t, "log 7", result[0].Message) - require.Equal(t, "log 1", result[6].Message) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := logSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.List(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestLog_Delete(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, _ := logSetup(t) - - for i := range 3 { - require.Nil(t, dao.Write(&models.Log{Data: map[string]any{}, Level: 0, Message: fmt.Sprintf("log %d", i+1)}, nil)) - time.Sleep(1 * time.Millisecond) - } - - err := dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"message": "log 2"}}, nil) - require.Nil(t, err) - }) - - t.Run("no db params", func(t *testing.T) { - dao, _ := logSetup(t) - - err := dao.Delete(nil, nil) - require.ErrorIs(t, err, ErrMissingWhere) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := logSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": "1234"}}, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} diff --git a/daos/scan.go b/daos/scan.go deleted file mode 100644 index 1022af6..0000000 --- a/daos/scan.go +++ /dev/null @@ -1,253 +0,0 @@ -package daos - -import ( - "database/sql" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/types" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ScanDao is the data access object for scans -type ScanDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewScanDao returns a new ScanDao -func NewScanDao(db database.Database) *ScanDao { - return &ScanDao{ - db: db, - table: "scans", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *ScanDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Create inserts a new scan -func (dao *ScanDao) Create(s *models.Scan, tx *database.Tx) error { - if s.ID == "" { - s.RefreshId() - } - - if s.Status.String() == "" { - s.Status = types.NewScanStatus(types.ScanStatusWaiting) - } - - s.RefreshCreatedAt() - s.RefreshUpdatedAt() - - // Default status to waiting - s.Status = types.NewScanStatus(types.ScanStatusWaiting) - - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(s)). - ToSql() - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - _, err := execFn(query, args...) - - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Get selects a scan with the given course ID -func (dao *ScanDao) Get(courseId string, tx *database.Tx) (*models.Scan, error) { - generic := NewGenericDao(dao.db, dao) - - dbParams := &database.DatabaseParams{ - Columns: dao.columns(), - Where: squirrel.Eq{dao.Table() + ".course_id": courseId}, - } - - row, err := generic.Get(dbParams, tx) - if err != nil { - return nil, err - } - - scan, err := dao.scanRow(row) - if err != nil { - return nil, err - } - - return scan, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Update updates a scan -// -// Note: Only the `status` can be updated -func (dao *ScanDao) Update(scan *models.Scan, tx *database.Tx) error { - if scan.ID == "" { - return ErrEmptyId - } - - scan.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Update(dao.Table()). - Set("status", NilStr(scan.Status.String())). - Set("updated_at", FormatTime(scan.UpdatedAt)). - Where("id = ?", scan.ID). - ToSql() - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - _, err := execFn(query, args...) - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Delete deletes a scan based upon the where clause -// -// `tx` allows for the function to be run within a transaction -func (dao *ScanDao) Delete(dbParams *database.DatabaseParams, tx *database.Tx) error { - if dbParams == nil || dbParams.Where == nil { - return ErrMissingWhere - } - - generic := NewGenericDao(dao.db, dao) - return generic.Delete(dbParams, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Next returns the next scan whose status is `waiting“ -func (dao *ScanDao) Next(tx *database.Tx) (*models.Scan, error) { - generic := NewGenericDao(dao.db, dao) - - dbParams := &database.DatabaseParams{ - Columns: dao.columns(), - Where: squirrel.Eq{dao.Table() + ".status": types.ScanStatusWaiting}, - OrderBy: []string{"created_at ASC"}, - } - - row, err := generic.Get(dbParams, tx) - if err != nil { - return nil, err - } - - scan, err := dao.scanRow(row) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - - return nil, err - } - - return scan, nil - -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default count select builder -func (dao *ScanDao) countSelect() squirrel.SelectBuilder { - courseDao := NewCourseDao(dao.db) - - return squirrel.StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - LeftJoin(courseDao.Table() + " ON " + dao.Table() + ".course_id = " + courseDao.Table() + ".id"). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// baseSelect returns the default select builder -// -// It performs 1 left join -// - courses table to get `course_path` -// -// Note: The columns are removed, so you must specify the columns with `.Columns(...)` when using -// this select builder -func (dao *ScanDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *ScanDao) columns() []string { - courseDao := NewCourseDao(dao.db) - - return []string{ - dao.Table() + ".*", - courseDao.Table() + ".path AS course_path", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for a scan -func (dao *ScanDao) data(s *models.Scan) map[string]any { - return map[string]any{ - "id": s.ID, - "course_id": NilStr(s.CourseID), - "status": NilStr(s.Status.String()), - "created_at": FormatTime(s.CreatedAt), - "updated_at": FormatTime(s.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans a scan row -func (dao *ScanDao) scanRow(scannable Scannable) (*models.Scan, error) { - var s models.Scan - - var createdAt string - var updatedAt string - - err := scannable.Scan( - &s.ID, - &s.CourseID, - &s.Status, - &createdAt, - &updatedAt, - &s.CoursePath, - ) - - if err != nil { - return nil, err - } - - if s.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if s.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - return &s, nil -} diff --git a/daos/scan_test.go b/daos/scan_test.go deleted file mode 100644 index 5a7e151..0000000 --- a/daos/scan_test.go +++ /dev/null @@ -1,276 +0,0 @@ -package daos - -import ( - "database/sql" - "fmt" - "testing" - "time" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/types" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func scanSetup(t *testing.T) (*ScanDao, database.Database) { - t.Helper() - - dbManager := setup(t) - scanDao := NewScanDao(dbManager.DataDb) - return scanDao, dbManager.DataDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestScan_Create(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - s := &models.Scan{CourseID: testData[0].Course.ID} - - err := dao.Create(s, nil) - require.Nil(t, err, "Failed to create scan") - - newS, err := dao.Get(s.CourseID, nil) - require.Nil(t, err) - require.Equal(t, s.ID, newS.ID) - require.True(t, newS.Status.IsWaiting()) - require.False(t, newS.CreatedAt.IsZero()) - require.False(t, newS.UpdatedAt.IsZero()) - - }) - - t.Run("duplicate course id", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - s := &models.Scan{CourseID: testData[0].Course.ID} - - err := dao.Create(s, nil) - require.Nil(t, err) - - err = dao.Create(s, nil) - require.ErrorContains(t, err, fmt.Sprintf("UNIQUE constraint failed: %s.course_id", dao.Table())) - }) - - t.Run("constraint errors", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Build() - - // Missing course ID - s := &models.Scan{} - require.ErrorContains(t, dao.Create(s, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - s.CourseID = "" - require.ErrorContains(t, dao.Create(s, nil), fmt.Sprintf("NOT NULL constraint failed: %s.course_id", dao.Table())) - s.CourseID = "1234" - - // Invalid Course ID - require.ErrorContains(t, dao.Create(s, nil), "FOREIGN KEY constraint failed") - s.CourseID = testData[0].Course.ID - - // Success - require.Nil(t, dao.Create(s, nil)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestScan_Get(t *testing.T) { - t.Run("found", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Scan().Build() - - s, err := dao.Get(testData[0].Course.ID, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Scan.ID, s.ID) - require.Equal(t, testData[0].Course.Path, s.CoursePath) - }) - - t.Run("not found", func(t *testing.T) { - dao, _ := scanSetup(t) - - s, err := dao.Get("1234", nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, s) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := scanSetup(t) - - s, err := dao.Get("", nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, s) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := scanSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Get("1234", nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestScan_Update(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Scan().Build() - require.True(t, testData[0].Scan.Status.IsWaiting()) - - // ---------------------------- - // Set to Processing - // ---------------------------- - testData[0].Scan.Status = types.NewScanStatus(types.ScanStatusProcessing) - require.Nil(t, dao.Update(testData[0].Scan, nil)) - - updatedScan, err := dao.Get(testData[0].Course.ID, nil) - require.Nil(t, err) - require.False(t, updatedScan.Status.IsWaiting()) - - // ---------------------------- - // Set to waiting - // ---------------------------- - time.Sleep(1 * time.Millisecond) - testData[0].Scan.Status = types.NewScanStatus(types.ScanStatusWaiting) - require.Nil(t, dao.Update(testData[0].Scan, nil)) - - updatedScan, err = dao.Get(testData[0].Course.ID, nil) - require.Nil(t, err) - require.True(t, updatedScan.Status.IsWaiting()) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := scanSetup(t) - - err := dao.Update(&models.Scan{}, nil) - require.ErrorIs(t, err, ErrEmptyId) - }) - - t.Run("invalid id", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Scan().Build() - testData[0].Scan.ID = "1234" - - err := dao.Update(testData[0].Scan, nil) - require.Nil(t, err) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := scanSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - testData := NewTestBuilder(t).Courses(1).Scan().Build() - err = dao.Update(testData[0].Scan, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestScan_Delete(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Scan().Build() - err := dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].ID}}, nil) - require.Nil(t, err) - }) - - t.Run("no db params", func(t *testing.T) { - dao, _ := scanSetup(t) - - err := dao.Delete(nil, nil) - require.ErrorIs(t, err, ErrMissingWhere) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := scanSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": "1234"}}, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestScan_DeleteCascade(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Scan().Build() - - // Delete the course - courseDao := NewCourseDao(db) - err := courseDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": testData[0].ID}}, nil) - require.Nil(t, err) - - // Check the scan was deleted - s, err := dao.Get(testData[0].ID, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, s) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestScan_NextScan(t *testing.T) { - t.Run("first", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Scan().Build() - - s, err := dao.Next(nil) - require.Nil(t, err) - require.Equal(t, testData[0].Scan.ID, s.ID) - require.Equal(t, testData[0].Path, s.CoursePath) - }) - - t.Run("next", func(t *testing.T) { - dao, db := scanSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(3).Scan().Build() - - // Update the the first scan to processing - testData[0].Scan.Status = types.NewScanStatus(types.ScanStatusProcessing) - require.Nil(t, dao.Update(testData[0].Scan, nil)) - - s, err := dao.Next(nil) - require.Nil(t, err) - require.Equal(t, testData[1].Scan.ID, s.ID) - require.Equal(t, testData[1].Path, s.CoursePath) - - }) - - t.Run("empty", func(t *testing.T) { - dao, _ := scanSetup(t) - - scan, err := dao.Next(nil) - require.Nil(t, err) - require.Nil(t, scan) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := scanSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Next(nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} diff --git a/daos/tag.go b/daos/tag.go deleted file mode 100644 index 20facf1..0000000 --- a/daos/tag.go +++ /dev/null @@ -1,342 +0,0 @@ -package daos - -import ( - "slices" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// TagDao is the data access object for tags -type TagDao struct { - db database.Database - table string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewTagDao returns a new TagDao -func NewTagDao(db database.Database) *TagDao { - return &TagDao{ - db: db, - table: "tags", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Table returns the table name -func (dao *TagDao) Table() string { - return dao.table -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Count returns the number of tags -func (dao *TagDao) Count(params *database.DatabaseParams, tx *database.Tx) (int, error) { - generic := NewGenericDao(dao.db, dao) - return generic.Count(params, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Create inserts a new tag -// -// `tx` allows for the function to be run within a transaction -func (dao *TagDao) Create(t *models.Tag, tx *database.Tx) error { - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - if t.ID == "" { - t.RefreshId() - } - - t.RefreshCreatedAt() - t.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Insert(dao.Table()). - SetMap(dao.data(t)). - ToSql() - - _, err := execFn(query, args...) - - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Get selects a tag with the given ID or name -// -// `tx` allows for the function to be run within a transaction -func (dao *TagDao) Get(id string, byName bool, dbParams *database.DatabaseParams, tx *database.Tx) (*models.Tag, error) { - generic := NewGenericDao(dao.db, dao) - - tagDbParams := &database.DatabaseParams{ - Columns: dao.columns(), - } - - if byName { - if dbParams != nil && dbParams.CaseInsensitive { - tagDbParams.Where = squirrel.Eq{dao.Table() + ".tag COLLATE NOCASE": id} - } else { - tagDbParams.Where = squirrel.Eq{dao.Table() + ".tag": id} - } - } else { - tagDbParams.Where = squirrel.Eq{dao.Table() + ".id": id} - } - - row, err := generic.Get(tagDbParams, tx) - if err != nil { - return nil, err - } - - tag, err := dao.scanRow(row) - if err != nil { - return nil, err - } - - // Get the course tags - courseTagDao := NewCourseTagDao(dao.db) - if dbParams != nil && slices.Contains(dbParams.IncludeRelations, courseTagDao.Table()) { - courseTagDbParams := &database.DatabaseParams{ - OrderBy: courseTagDao.ProcessOrderBy(dbParams.OrderBy, true), - Where: squirrel.Eq{"tag_id": id}, - } - - // Get the course_tags - courseTags, err := courseTagDao.List(courseTagDbParams, tx) - if err != nil { - return nil, err - } - - tag.CourseTags = courseTags - } - - return tag, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// List selects tags -// -// `tx` allows for the function to be run within a transaction -func (dao *TagDao) List(dbParams *database.DatabaseParams, tx *database.Tx) ([]*models.Tag, error) { - generic := NewGenericDao(dao.db, dao) - - if dbParams == nil { - dbParams = &database.DatabaseParams{} - } - - origOrderBy := dbParams.OrderBy - - dbParams.OrderBy = dao.ProcessOrderBy(dbParams.OrderBy, false) - - // Default the columns if not specified - if len(dbParams.Columns) == 0 { - dbParams.Columns = dao.columns() - } - - rows, err := generic.List(dbParams, tx) - if err != nil { - return nil, err - } - defer rows.Close() - - var tags []*models.Tag - tagIds := []string{} - - for rows.Next() { - t, err := dao.scanRow(rows) - if err != nil { - return nil, err - } - - tags = append(tags, t) - tagIds = append(tagIds, t.ID) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - // Get the course_tags - courseTagDao := NewCourseTagDao(dao.db) - if len(tags) > 0 && slices.Contains(dbParams.IncludeRelations, courseTagDao.Table()) { - // Reduce the order by clause to only include columns specific to the course_tags table - reducedOrderBy := courseTagDao.ProcessOrderBy(origOrderBy, true) - - dbParams = &database.DatabaseParams{ - OrderBy: reducedOrderBy, - Where: squirrel.Eq{"tag_id": tagIds}, - } - - // Get the course_tags - courseTags, err := courseTagDao.List(dbParams, tx) - if err != nil { - return nil, err - } - - // Map the course_tags to the tags - tagMap := map[string][]*models.CourseTag{} - for _, ct := range courseTags { - tagMap[ct.TagId] = append(tagMap[ct.TagId], ct) - } - - // Assign the course_tags to the tags - for _, t := range tags { - t.CourseTags = tagMap[t.ID] - } - } - - return tags, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Update updates a tag -// -// Note: Only `tag` can be updated -func (dao *TagDao) Update(tag *models.Tag, tx *database.Tx) error { - if tag.ID == "" { - return ErrEmptyId - } - - tag.RefreshUpdatedAt() - - query, args, _ := squirrel. - StatementBuilder. - Update(dao.Table()). - Set("tag", NilStr(tag.Tag)). - Set("updated_at", FormatTime(tag.UpdatedAt)). - Where("id = ?", tag.ID). - ToSql() - - execFn := dao.db.Exec - if tx != nil { - execFn = tx.Exec - } - - _, err := execFn(query, args...) - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Delete deletes a tag based upon the where clause -// -// `tx` allows for the function to be run within a transaction -func (dao *TagDao) Delete(dbParams *database.DatabaseParams, tx *database.Tx) error { - if dbParams == nil || dbParams.Where == nil { - return ErrMissingWhere - } - - generic := NewGenericDao(dao.db, dao) - return generic.Delete(dbParams, tx) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// ProcessOrderBy takes an array of strings representing orderBy clauses and returns a processed -// version of this array -// -// It will creates a new list of valid table columns based upon columns() for the current -// DAO -func (dao *TagDao) ProcessOrderBy(orderBy []string, explicit bool) []string { - if len(orderBy) == 0 { - return orderBy - } - - generic := NewGenericDao(dao.db, dao) - return generic.ProcessOrderBy(orderBy, dao.columns(), explicit) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Internal -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// countSelect returns the default count select builder -func (dao *TagDao) countSelect() squirrel.SelectBuilder { - courseTagDao := NewCourseTagDao(dao.db) - - return squirrel. - StatementBuilder. - PlaceholderFormat(squirrel.Question). - Select(""). - From(dao.Table()). - LeftJoin(courseTagDao.Table() + " ON " + dao.Table() + ".id = " + courseTagDao.Table() + ".tag_id"). - RemoveColumns() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// baseSelect returns the default select builder -// -// Note: The columns are removed, so you must specify the columns with `.Columns(...)` when using -// this select builder -func (dao *TagDao) baseSelect() squirrel.SelectBuilder { - return dao.countSelect().GroupBy("tags.id") - // GroupBy("tags.id", "tags.tag", "tags.created_at", "tags.updated_at"). -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// columns returns the columns to select -func (dao *TagDao) columns() []string { - courseTagDao := NewCourseTagDao(dao.db) - - return []string{ - dao.Table() + ".*", - "COALESCE(COUNT(" + courseTagDao.Table() + ".id), 0) AS course_count", - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// data generates a map of key/values for a tag -func (dao *TagDao) data(t *models.Tag) map[string]any { - return map[string]any{ - "id": t.ID, - "tag": NilStr(t.Tag), - "created_at": FormatTime(t.CreatedAt), - "updated_at": FormatTime(t.UpdatedAt), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// scanRow scans an tag row -func (dao *TagDao) scanRow(scannable Scannable) (*models.Tag, error) { - var t models.Tag - - var createdAt string - var updatedAt string - - err := scannable.Scan( - &t.ID, - &t.Tag, - &createdAt, - &updatedAt, - &t.CourseCount, - ) - - if err != nil { - return nil, err - } - - if t.CreatedAt, err = ParseTime(createdAt); err != nil { - return nil, err - } - - if t.UpdatedAt, err = ParseTime(updatedAt); err != nil { - return nil, err - } - - return &t, nil -} diff --git a/daos/tag_test.go b/daos/tag_test.go deleted file mode 100644 index 4cf3bb9..0000000 --- a/daos/tag_test.go +++ /dev/null @@ -1,463 +0,0 @@ -package daos - -import ( - "database/sql" - "fmt" - "testing" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/pagination" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func tagSetup(t *testing.T) (*TagDao, database.Database) { - t.Helper() - - dbManager := setup(t) - tagDao := NewTagDao(dbManager.DataDb) - return tagDao, dbManager.DataDb -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestTag_Count(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := tagSetup(t) - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Zero(t, count) - }) - - t.Run("entries", func(t *testing.T) { - dao, _ := tagSetup(t) - - // Add test_tags into the database - for _, tag := range test_tags { - require.Nil(t, dao.Create(&models.Tag{Tag: tag}, nil)) - } - - count, err := dao.Count(nil, nil) - require.Nil(t, err) - require.Equal(t, count, len(test_tags)) - }) - - t.Run("where", func(t *testing.T) { - dao, _ := tagSetup(t) - - // Add test_tags into the database - for _, tag := range test_tags { - require.Nil(t, dao.Create(&models.Tag{Tag: tag}, nil)) - } - - // ---------------------------- - // EQUALS - // ---------------------------- - count, err := dao.Count(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".tag": test_tags[0]}}, nil) - require.Nil(t, err) - require.Equal(t, 1, count) - - // ---------------------------- - // NOT EQUALS - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.NotEq{dao.Table() + ".tag": test_tags[0]}}, nil) - require.Nil(t, err) - require.Equal(t, 19, count) - - // ---------------------------- - // STARTS WITH (Java%) - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Like{dao.Table() + ".tag": "Java%"}}, nil) - require.Nil(t, err) - require.Equal(t, 2, count) - - // ---------------------------- - // ERROR - // ---------------------------- - count, err = dao.Count(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Zero(t, count) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := tagSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Count(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestTag_Create(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, _ := tagSetup(t) - - tag := &models.Tag{ - Tag: "JavaScript", - } - - err := dao.Create(tag, nil) - require.Nil(t, err) - }) - - t.Run("duplicate tags", func(t *testing.T) { - dao, _ := tagSetup(t) - - tag := &models.Tag{ - Tag: "JavaScript", - } - - // Create the tag - require.Nil(t, dao.Create(tag, nil)) - - // Create the asset (again) - require.ErrorContains(t, dao.Create(tag, nil), fmt.Sprintf("UNIQUE constraint failed: %s.tag", dao.Table())) - }) - - t.Run("constraints", func(t *testing.T) { - dao, _ := tagSetup(t) - - // Empty tag ID - tag := &models.Tag{} - require.ErrorContains(t, dao.Create(tag, nil), fmt.Sprintf("NOT NULL constraint failed: %s.tag", dao.Table())) - - // Success - tag.Tag = "JavaScript" - require.Nil(t, dao.Create(tag, nil)) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestTag_Get(t *testing.T) { - t.Run("found", func(t *testing.T) { - dao, db := tagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses([]string{"course 1", "course 2"}).Tags([]string{"Go", "TypeScript"}).Build() - - // Get the first tag - tag, err := dao.Get(testData[0].Tags[0].TagId, false, nil, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Tags[0].TagId, tag.ID) - - // By Name (Go) - tag, err = dao.Get("Go", true, nil, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Tags[0].TagId, tag.ID) - require.Equal(t, 2, tag.CourseCount) - - // ---------------------------- - // Course tags - // ---------------------------- - dbParams := &database.DatabaseParams{ - OrderBy: []string{NewCourseDao(dao.db).Table() + ".title asc"}, - IncludeRelations: []string{NewCourseTagDao(dao.db).Table()}, - } - - tag, err = dao.Get(testData[0].Tags[0].TagId, false, dbParams, nil) - require.Nil(t, err) - require.Len(t, tag.CourseTags, 2) - require.Equal(t, testData[0].ID, tag.CourseTags[0].CourseId) - require.Equal(t, testData[1].ID, tag.CourseTags[1].CourseId) - - // ---------------------------- - // Case Insensitive - // ---------------------------- - dbParams = &database.DatabaseParams{ - CaseInsensitive: true, - } - - tag, err = dao.Get("go", true, dbParams, nil) - require.Nil(t, err) - require.Equal(t, testData[0].Tags[0].TagId, tag.ID) - - }) - - t.Run("not found", func(t *testing.T) { - dao, _ := tagSetup(t) - - c, err := dao.Get("1234", false, nil, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, c) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := tagSetup(t) - - c, err := dao.Get("", false, nil, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, c) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := tagSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.Get("1234", false, nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestTag_List(t *testing.T) { - t.Run("no entries", func(t *testing.T) { - dao, _ := tagSetup(t) - - tags, err := dao.List(nil, nil) - require.Nil(t, err) - require.Zero(t, tags) - }) - - t.Run("found", func(t *testing.T) { - dao, _ := tagSetup(t) - - NewTestBuilder(t).Db(dao.db).Courses([]string{"course 1"}).Tags([]string{"PHP", "Go"}).Build() - NewTestBuilder(t).Db(dao.db).Courses([]string{"course 2"}).Tags([]string{"Go", "C"}).Build() - NewTestBuilder(t).Db(dao.db).Courses([]string{"course 3"}).Tags([]string{"C", "TypeScript"}).Build() - - dbParams := &database.DatabaseParams{ - OrderBy: []string{dao.Table() + ".tag asc"}, - } - - result, err := dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 4) - require.Nil(t, result[0].CourseTags) - - require.Equal(t, 2, result[0].CourseCount) // C - require.Equal(t, 2, result[1].CourseCount) // GO - require.Equal(t, 1, result[2].CourseCount) // PHP - require.Equal(t, 1, result[3].CourseCount) // TypeScript - - // ---------------------------- - // Course tags - // ---------------------------- - - dbParams.IncludeRelations = []string{NewCourseTagDao(dao.db).Table()} - - result, err = dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 4) - - require.Len(t, result[0].CourseTags, 2) // C - require.Len(t, result[1].CourseTags, 2) // GO - require.Len(t, result[2].CourseTags, 1) // PHP - require.Len(t, result[3].CourseTags, 1) // TypeScript - - }) - - t.Run("orderby", func(t *testing.T) { - dao, _ := tagSetup(t) - - testData := NewTestBuilder(t). - Db(dao.db). - Courses([]string{"course 1", "course 2", "course 3"}). - Tags([]string{"PHP", "Go", "Java", "TypeScript", "C"}).Build() - - // ---------------------------- - // TAG DESC - // ---------------------------- - result, err := dao.List(&database.DatabaseParams{OrderBy: []string{"tag desc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 5) - require.Equal(t, "TypeScript", result[0].Tag) - - // ---------------------------- - // TAG ASC - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{"tag asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 5) - require.Equal(t, "C", result[0].Tag) - - // ---------------------------- - // CREATED_AT ASC + COURSES.TITLE DESC - // ---------------------------- - dbParams := &database.DatabaseParams{ - OrderBy: []string{"tag asc", NewCourseDao(dao.db).Table() + ".title desc"}, - IncludeRelations: []string{NewCourseTagDao(dao.db).Table()}, - } - - result, err = dao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, result, 5) - require.Equal(t, "C", result[0].Tag) - require.Equal(t, testData[2].ID, result[0].CourseTags[0].CourseId) - - // ---------------------------- - // Error - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{OrderBy: []string{"unit_test asc"}}, nil) - require.ErrorContains(t, err, "no such column") - require.Nil(t, result) - }) - - t.Run("where", func(t *testing.T) { - dao, _ := tagSetup(t) - - for _, tag := range test_tags { - require.Nil(t, dao.Create(&models.Tag{Tag: tag}, nil)) - } - - // ---------------------------- - // EQUALS (PHP) - // ---------------------------- - result, err := dao.List(&database.DatabaseParams{Where: squirrel.Eq{dao.Table() + ".tag": "PHP"}}, nil) - require.Nil(t, err) - require.Len(t, result, 1) - - // ---------------------------- - // LIKE (Java%) - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Like{dao.Table() + ".tag": "Java%"}}, nil) - require.Nil(t, err) - require.Len(t, result, 2) - - // ---------------------------- - // ERROR - // ---------------------------- - result, err = dao.List(&database.DatabaseParams{Where: squirrel.Eq{"": ""}}, nil) - require.ErrorContains(t, err, "syntax error") - require.Nil(t, result) - }) - - t.Run("pagination", func(t *testing.T) { - dao, _ := tagSetup(t) - - for _, tag := range test_tags { - require.Nil(t, dao.Create(&models.Tag{Tag: tag}, nil)) - } - - // ---------------------------- - // Page 1 with 10 items - // ---------------------------- - p := pagination.New(1, 10) - - result, err := dao.List(&database.DatabaseParams{Pagination: p, OrderBy: []string{"tag asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, 20, p.TotalItems()) - require.Equal(t, "C", result[0].Tag) - - // ---------------------------- - // Page 2 with 10 items - // ---------------------------- - p = pagination.New(2, 10) - - result, err = dao.List(&database.DatabaseParams{Pagination: p, OrderBy: []string{"tag asc"}}, nil) - require.Nil(t, err) - require.Len(t, result, 10) - require.Equal(t, 20, p.TotalItems()) - require.Equal(t, "Perl", result[0].Tag) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := tagSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - _, err = dao.List(nil, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestTag_Update(t *testing.T) { - t.Run("tag", func(t *testing.T) { - dao, db := tagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Tags([]string{"Go"}).Build() - - tag, err := dao.Get(testData[0].Tags[0].TagId, false, nil, nil) - require.Nil(t, err) - require.Equal(t, "Go", tag.Tag) - - // Update the tag - tag.Tag = "go" - require.Nil(t, dao.Update(tag, nil)) - - updatedTag, err := dao.Get(testData[0].Tags[0].TagId, false, nil, nil) - require.Nil(t, err) - require.Equal(t, "go", updatedTag.Tag) - }) - - t.Run("empty id", func(t *testing.T) { - dao, _ := tagSetup(t) - - err := dao.Update(&models.Tag{}, nil) - require.ErrorIs(t, err, ErrEmptyId) - }) - - t.Run("invalid id", func(t *testing.T) { - dao, db := tagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Tags(1).Build() - - tag, err := dao.Get(testData[0].Tags[0].TagId, false, nil, nil) - require.Nil(t, err) - - tag.ID = "1234" - require.Nil(t, dao.Update(tag, nil)) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := tagSetup(t) - - testData := NewTestBuilder(t).Db(db).Courses(1).Tags(1).Build() - - tag, err := dao.Get(testData[0].Tags[0].TagId, false, nil, nil) - require.Nil(t, err) - - _, err = db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Update(tag, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestTag_Delete(t *testing.T) { - t.Run("success", func(t *testing.T) { - dao, _ := tagSetup(t) - - // Add test_tags into the database - for _, tag := range test_tags { - require.Nil(t, dao.Create(&models.Tag{Tag: tag}, nil)) - } - - err := dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"tag": test_tags[0]}}, nil) - require.Nil(t, err) - }) - - t.Run("no db params", func(t *testing.T) { - dao, _ := tagSetup(t) - - err := dao.Delete(nil, nil) - require.ErrorIs(t, err, ErrMissingWhere) - }) - - t.Run("db error", func(t *testing.T) { - dao, db := tagSetup(t) - - _, err := db.Exec("DROP TABLE IF EXISTS " + dao.Table()) - require.Nil(t, err) - - err = dao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"tag": "1234"}}, nil) - require.ErrorContains(t, err, "no such table: "+dao.Table()) - }) -} diff --git a/daos/test_data.go b/daos/test_data.go deleted file mode 100644 index 568215b..0000000 --- a/daos/test_data.go +++ /dev/null @@ -1,326 +0,0 @@ -package daos - -import ( - "database/sql" - "fmt" - "math/rand" - "path/filepath" - "testing" - "time" - - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/security" - "github.com/geerew/off-course/utils/types" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Slice of 20 tags for testing (programming languages) -var test_tags = []string{ - "JavaScript", "Python", "Java", "Ruby", "PHP", - "TypeScript", "C#", "C++", "C", "Swift", - "Kotlin", "Rust", "Go", "Perl", "Scala", - "R", "Objective-C", "Shell", "PowerShell", "Haskell", -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -type TestBuilder struct { - // A testing object. Used to validate assertions into the DB - t *testing.T - - // The database - db database.Database - - // How many courses to create OR a list of course titles - numberOfCourses int - courseTitles []string - - // Whether to create a scan per course - scan bool - - // How many assets per course - assetsPerCourse int - // How many attachments per asset - attachmentsPerAsset int - - // How many tags per course OR a list of tags - tagsPerCourse int - tags []string -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -type TestCourse struct { - *models.Course - Scan *models.Scan - Assets []*models.Asset - Tags []*models.CourseTag -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func NewTestBuilder(t *testing.T) *TestBuilder { - return &TestBuilder{ - t: t, - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Db sets the database -func (builder *TestBuilder) Db(db database.Database) *TestBuilder { - builder.db = db - return builder -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Courses sets the number of courses -func (builder *TestBuilder) Courses(courses any) *TestBuilder { - switch c := courses.(type) { - case int: - builder.numberOfCourses = c - case []string: - builder.courseTitles = c - builder.numberOfCourses = len(c) - } - return builder -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Scan sets a scan per course -func (builder *TestBuilder) Scan() *TestBuilder { - builder.scan = true - return builder -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Assets sets the number of assets per course -func (builder *TestBuilder) Assets(assetsPerCourse int) *TestBuilder { - builder.assetsPerCourse = assetsPerCourse - return builder -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Attachments sets the number of attachments per asset -func (builder *TestBuilder) Attachments(attachmentsPerAsset int) *TestBuilder { - builder.attachmentsPerAsset = attachmentsPerAsset - return builder -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Tags sets either a random number of tags per course or a specific set of tags -func (builder *TestBuilder) Tags(tags any) *TestBuilder { - switch t := tags.(type) { - case int: - builder.tagsPerCourse = t - case []string: - builder.tags = t - builder.tagsPerCourse = len(t) - } - return builder -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func (builder *TestBuilder) Build() []*TestCourse { - var testCourses []*TestCourse - - for i := 0; i < builder.numberOfCourses; i++ { - tc := &TestCourse{} - - title := "" - if len(builder.courseTitles) > 0 { - title = builder.courseTitles[i] - } else { - title = fmt.Sprintf("Course %d", i+1) - } - - tc.Course = builder.newTestCourse(title) - - if builder.scan { - tc.Scan = builder.newTestScan(tc.Course.ID) - } - - if builder.assetsPerCourse > 0 { - tc.Assets = builder.newTestAssets(tc.Course) - } - - if builder.tagsPerCourse > 0 && builder.db != nil { - tc.Tags = builder.newTestTags(tc.Course) - } - - testCourses = append(testCourses, tc) - } - - return testCourses -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func (builder *TestBuilder) newTestCourse(title string) *models.Course { - c := &models.Course{} - - c.RefreshId() - c.RefreshCreatedAt() - c.RefreshUpdatedAt() - - c.Title = title - c.Path = filepath.Join(string(filepath.Separator), "courses", c.Title) - - if builder.db != nil { - dao := NewCourseDao(builder.db) - - err := dao.Create(c) - require.NoError(builder.t, err, "Failed to create course") - - time.Sleep(time.Millisecond * 1) - } - - return c -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func (builder *TestBuilder) newTestScan(courseId string) *models.Scan { - s := &models.Scan{} - - s.RefreshId() - s.RefreshCreatedAt() - s.RefreshUpdatedAt() - - s.CourseID = courseId - s.Status = types.NewScanStatus(types.ScanStatusWaiting) - - if builder.db != nil { - dao := NewScanDao(builder.db) - - err := dao.Create(s, nil) - require.Nil(builder.t, err) - - time.Sleep(time.Millisecond * 1) - } - - return s -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func (builder *TestBuilder) newTestAssets(course *models.Course) []*models.Asset { - assets := []*models.Asset{} - - for i := 0; i < builder.assetsPerCourse; i++ { - a := &models.Asset{} - - a.RefreshId() - a.RefreshCreatedAt() - a.RefreshUpdatedAt() - - a.CourseID = course.ID - a.Title = fmt.Sprintf("asset %d", i+1) - a.Prefix = sql.NullInt16{Int16: int16(i + 1), Valid: true} - a.Chapter = fmt.Sprintf("%d chapter %s", i+1, security.PseudorandomString(2)) - a.Type = *types.NewAsset("mp4") - a.Path = filepath.Join(course.Path, a.Chapter, fmt.Sprintf("%d", a.Prefix.Int16), a.Title+a.Type.String()) - a.Hash = security.PseudorandomString(32) - - if builder.db != nil { - dao := NewAssetDao(builder.db) - - err := dao.Create(a, nil) - require.Nil(builder.t, err) - - time.Sleep(time.Millisecond * 1) - } - - if builder.attachmentsPerAsset > 0 { - a.Attachments = builder.newTestAttachments(a) - } - - assets = append(assets, a) - } - - return assets -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func (builder *TestBuilder) newTestAttachments(asset *models.Asset) []*models.Attachment { - attachments := []*models.Attachment{} - - for i := 0; i < builder.attachmentsPerAsset; i++ { - a := &models.Attachment{} - - a.RefreshId() - a.RefreshCreatedAt() - a.RefreshUpdatedAt() - - a.CourseID = asset.CourseID - a.AssetID = asset.ID - a.Title = fmt.Sprintf("attachment %d", i+1) - a.Path = filepath.Join(filepath.Dir(asset.Path), fmt.Sprintf("%d", asset.Prefix.Int16), a.Title) - - if builder.db != nil { - dao := NewAttachmentDao(builder.db) - - err := dao.Create(a, nil) - require.Nil(builder.t, err) - - time.Sleep(time.Millisecond * 1) - } - - attachments = append(attachments, a) - - } - - return attachments -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func (builder *TestBuilder) newTestTags(course *models.Course) []*models.CourseTag { - if builder.db == nil { - return nil - } - - tags := []*models.CourseTag{} - chosenTags := map[string]bool{} - - for i := 0; i < builder.tagsPerCourse; i++ { - var tag string - - if len(builder.tags) > 0 { - tag = builder.tags[i] - } else { - for { - randomTag := test_tags[rand.Intn(len(test_tags))] - if !chosenTags[randomTag] { - tag = randomTag - chosenTags[randomTag] = true - break - } - } - } - - ct := &models.CourseTag{ - CourseId: course.ID, - Tag: tag, - } - - dao := NewCourseTagDao(builder.db) - require.Nil(builder.t, dao.Create(ct, nil)) - - tags = append(tags, ct) - - time.Sleep(time.Millisecond * 1) - } - - return tags -} diff --git a/database/base.go b/database/base.go index 7b2cd34..52d13b5 100644 --- a/database/base.go +++ b/database/base.go @@ -1,6 +1,7 @@ package database import ( + "context" "database/sql" "log/slog" @@ -19,6 +20,31 @@ var ( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +type contextKey string + +const querierKey = contextKey("querier") + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// WithQuerier adds a querier to the context +func WithQuerier(ctx context.Context, querier Querier) context.Context { + return context.WithValue(ctx, querierKey, querier) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// QuerierFromContext returns the querier from the context, defaulting to a defaulted querier if +// not found +func QuerierFromContext(ctx context.Context, defaultQuerier Querier) Querier { + if querier, ok := ctx.Value(querierKey).(Querier); ok && querier != nil { + return querier + } + + return defaultQuerier +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Defines the sql functions type ( ExecFn = func(query string, args ...interface{}) (sql.Result, error) @@ -30,18 +56,23 @@ type ( // Database defines the interface for a database type Database interface { + Querier + RunInTransaction(context.Context, func(context.Context) error) error + SetLogger(*slog.Logger) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +type Querier interface { Exec(query string, args ...any) (sql.Result, error) Query(query string, args ...any) (*sql.Rows, error) QueryRow(query string, args ...any) *sql.Row - Begin(opts *sql.TxOptions) (*sql.Tx, error) - RunInTransaction(txFunc func(*Tx) error) error - SetLogger(logger *slog.Logger) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// DatabaseParams defines optional params for a database query -type DatabaseParams struct { +// Options defines optional params for a database query +type Options struct { // A slice of columns to order by (ex ["id DESC", "title ASC"]) OrderBy []string @@ -53,16 +84,16 @@ type DatabaseParams struct { // // Examples: // - // EQ: sq.Eq{"id": "123"} - // IN: sq.Eq{"id": []string{"123", "456"}} - // OR: sq.Or{sq.Expr("id = ?", "123"), sq.Expr("id = ?", "456")} - // AND: sq.And{sq.Eq{"id": "123"}, sq.Eq{"title": "devops"}} - // LIKE: sq.Like{"title": "%dev%"} - // NOT: sq.NotEq{"id": "123"} - Where any + // EQ: squirrel.Eq{"id": "123"} + // IN: squirrel.Eq{"id": []string{"123", "456"}} + // OR: squirrel.Or{squirrel.Expr("id = ?", "123"), squirrel.Expr("id = ?", "456")} + // AND: squirrel.And{squirrel.Eq{"id": "123"}, squirrel.Eq{"title": "devops"}} + // LIKE: squirrel.Like{"title": "%dev%"} + // NOT: squirrel.NotEq{"id": "123"} + Where squirrel.Sqlizer // Columns to group by - GroupBys []string + GroupBy []string // Limit the results Having squirrel.Sqlizer diff --git a/database/base_test.go b/database/base_test.go index 8231375..abfb071 100644 --- a/database/base_test.go +++ b/database/base_test.go @@ -17,7 +17,7 @@ func Test_NewSqliteDBManager(t *testing.T) { BatchSize: 1, WriteFn: logger.NilWriteFn(), }) - require.Nil(t, err) + require.NoError(t, err) appFs := appFs.NewAppFs(afero.NewMemMapFs(), logger) @@ -28,7 +28,7 @@ func Test_NewSqliteDBManager(t *testing.T) { InMemory: true, }) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, dbManager) }) @@ -38,7 +38,7 @@ func Test_NewSqliteDBManager(t *testing.T) { BatchSize: 1, WriteFn: logger.NilWriteFn(), }) - require.Nil(t, err) + require.NoError(t, err) appFs := appFs.NewAppFs(afero.NewReadOnlyFs(afero.NewMemMapFs()), logger) diff --git a/database/sqlite.go b/database/sqlite.go index 7f66246..96b4e6a 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -1,6 +1,7 @@ package database import ( + "context" "database/sql" "fmt" "log/slog" @@ -102,20 +103,17 @@ func (db *SqliteDb) Exec(query string, args ...any) (sql.Result, error) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Begin starts a new transaction -// -// It implements the Database interface -func (db *SqliteDb) Begin(opts *sql.TxOptions) (*sql.Tx, error) { - return db.DB.Begin() -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // RunInTransaction runs a function in a transaction // // It implements the Database interface -func (db *SqliteDb) RunInTransaction(txFunc func(*Tx) error) error { - slqTx, err := db.DB.Begin() +func (db *SqliteDb) RunInTransaction(ctx context.Context, txFunc func(context.Context) error) (err error) { + // Check if there's an existing querier in the context + existingQuerier := QuerierFromContext(ctx, nil) + if existingQuerier != nil { + return txFunc(ctx) + } + + slqTx, err := db.DB.BeginTx(ctx, nil) if err != nil { return err } @@ -125,6 +123,8 @@ func (db *SqliteDb) RunInTransaction(txFunc func(*Tx) error) error { db: db, } + txCtx := WithQuerier(ctx, tx) + defer func() { if p := recover(); p != nil { tx.Rollback() @@ -136,7 +136,7 @@ func (db *SqliteDb) RunInTransaction(txFunc func(*Tx) error) error { } }() - err = txFunc(tx) + err = txFunc(txCtx) return err } diff --git a/database/sqlite_test.go b/database/sqlite_test.go index bf9df0a..8c00e07 100644 --- a/database/sqlite_test.go +++ b/database/sqlite_test.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "testing" @@ -33,12 +34,12 @@ func setupSqliteDB(t *testing.T) *DatabaseManager { InMemory: true, }) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, dbManager) // Test table _, err = dbManager.DataDb.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") - require.Nil(t, err) + require.NoError(t, err) return dbManager } @@ -51,7 +52,7 @@ func TestSqliteDb_Bootstrap(t *testing.T) { BatchSize: 1, WriteFn: logger.NilWriteFn(), }) - require.Nil(t, err) + require.NoError(t, err) appFs := appFs.NewAppFs(afero.NewMemMapFs(), logger) @@ -64,7 +65,7 @@ func TestSqliteDb_Bootstrap(t *testing.T) { InMemory: true, }) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, db) }) @@ -74,7 +75,7 @@ func TestSqliteDb_Bootstrap(t *testing.T) { BatchSize: 1, WriteFn: logger.NilWriteFn(), }) - require.Nil(t, err) + require.NoError(t, err) appFs := appFs.NewAppFs(afero.NewReadOnlyFs(afero.NewMemMapFs()), logger) @@ -97,7 +98,7 @@ func TestSqliteDb_Bootstrap(t *testing.T) { BatchSize: 1, WriteFn: logger.NilWriteFn(), }) - require.Nil(t, err) + require.NoError(t, err) appFs := appFs.NewAppFs(afero.NewMemMapFs(), logger) @@ -122,17 +123,17 @@ func TestSqliteDb_Query(t *testing.T) { dbManager := setupSqliteDB(t) _, err := dbManager.DataDb.Exec("INSERT INTO test (name) VALUES ('test')") - require.Nil(t, err) + require.NoError(t, err) rows, err := dbManager.DataDb.Query("SELECT * FROM test") - require.Nil(t, err) + require.NoError(t, err) defer rows.Close() for rows.Next() { var id int var name string err = rows.Scan(&id, &name) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, 1, id) require.Equal(t, "test", name) } @@ -146,13 +147,13 @@ func TestSqliteDb_QueryRow(t *testing.T) { dbManager := setupSqliteDB(t) _, err := dbManager.DataDb.Exec("INSERT INTO test (name) VALUES ('test')") - require.Nil(t, err) + require.NoError(t, err) var id int var name string err = dbManager.DataDb.QueryRow("SELECT * FROM test").Scan(&id, &name) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, "test", name) } @@ -162,42 +163,24 @@ func TestSqliteDb_Exec(t *testing.T) { dbManager := setupSqliteDB(t) result, err := dbManager.DataDb.Exec("INSERT INTO test (name) VALUES ('test')") - require.Nil(t, err) + require.NoError(t, err) rowAffected, err := result.RowsAffected() - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, int64(1), rowAffected) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func TestSqliteDb_Begin(t *testing.T) { - dbManager := setupSqliteDB(t) - - tx, err := dbManager.DataDb.Begin(nil) - require.Nil(t, err) - - _, err = tx.Exec("INSERT INTO test (name) VALUES ('test')") - require.Nil(t, err) - - err = tx.Commit() - require.Nil(t, err) - - var count int - err = dbManager.DataDb.QueryRow("SELECT COUNT(*) FROM test").Scan(&count) - require.Nil(t, err) - require.Equal(t, 1, count) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - func TestSqliteDb_RunInTransaction(t *testing.T) { t.Run(("error"), func(t *testing.T) { dbManager := setupSqliteDB(t) - err := dbManager.DataDb.RunInTransaction(func(tx *Tx) error { - _, err := tx.Exec("INSERT INTO test (name) VALUES ('test')") + ctx := context.Background() + err := dbManager.DataDb.RunInTransaction(ctx, func(txCtx context.Context) error { + q := QuerierFromContext(txCtx, dbManager.DataDb) + _, err := q.Exec("INSERT INTO test (name) VALUES ('test')") if err != nil { return err } @@ -214,15 +197,17 @@ func TestSqliteDb_RunInTransaction(t *testing.T) { var count int err = dbManager.DataDb.QueryRow("SELECT COUNT(*) FROM test").Scan(&count) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, 0, count) }) t.Run(("success"), func(t *testing.T) { dbManager := setupSqliteDB(t) - err := dbManager.DataDb.RunInTransaction(func(tx *Tx) error { - _, err := tx.Exec("INSERT INTO test (name) VALUES ('test')") + ctx := context.Background() + err := dbManager.DataDb.RunInTransaction(ctx, func(txCtx context.Context) error { + q := QuerierFromContext(txCtx, dbManager.DataDb) + _, err := q.Exec("INSERT INTO test (name) VALUES ('test')") if err != nil { return err } @@ -230,11 +215,11 @@ func TestSqliteDb_RunInTransaction(t *testing.T) { return nil }) - require.Nil(t, err) + require.NoError(t, err) var count int err = dbManager.DataDb.QueryRow("SELECT COUNT(*) FROM test").Scan(&count) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, 1, count) }) } diff --git a/go.mod b/go.mod index a31dd34..9924e3a 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.4 require ( github.com/Masterminds/squirrel v1.5.4 github.com/gofiber/fiber/v2 v2.52.5 + github.com/mattn/go-sqlite3 v1.14.22 github.com/pressly/goose/v3 v3.19.2 github.com/robfig/cron/v3 v3.0.1 github.com/shirou/gopsutil/v3 v3.24.3 diff --git a/main.go b/main.go index 52a14c1..24c1313 100644 --- a/main.go +++ b/main.go @@ -13,11 +13,11 @@ import ( "github.com/geerew/off-course/api" "github.com/geerew/off-course/cron" - "github.com/geerew/off-course/daos" + "github.com/geerew/off-course/dao" "github.com/geerew/off-course/database" "github.com/geerew/off-course/models" "github.com/geerew/off-course/utils/appFs" - "github.com/geerew/off-course/utils/jobs" + "github.com/geerew/off-course/utils/coursescan" "github.com/geerew/off-course/utils/logger" "github.com/geerew/off-course/utils/security" "github.com/spf13/afero" @@ -35,6 +35,8 @@ func main() { isDebug := flag.Bool("debug", false, "verbose") flag.Parse() + ctx := context.Background() + // Create app filesystem appFs := appFs.NewAppFs(afero.NewOsFs(), nil) @@ -53,7 +55,7 @@ func main() { logger, loggerDone, err := logger.InitLogger(&logger.BatchOptions{ BatchSize: 200, BeforeAddFn: loggerBeforeAddFn(dbManager.LogsDb), - WriteFn: loggerWriteFn(dbManager.LogsDb), + WriteFn: loggerWriteFn(ctx, dbManager.LogsDb), }) if err != nil { @@ -65,14 +67,14 @@ func main() { appFs.SetLogger(logger) // Course scanner - courseScanner := jobs.NewCourseScanner(&jobs.CourseScannerConfig{ + courseScan := coursescan.NewCourseScan(&coursescan.CourseScanConfig{ Db: dbManager.DataDb, AppFs: appFs, Logger: logger, }) // Start the worker (pass in the func that will process the job) - go courseScanner.Worker(jobs.CourseProcessor, nil) + go courseScan.Worker(ctx, coursescan.Processor, nil) // Initialize cron jobs cron.InitCron(&cron.CronConfig{ @@ -82,13 +84,13 @@ func main() { }) // Create router - router := api.New(&api.RouterConfig{ - DbManager: dbManager, - Logger: logger, - AppFs: appFs, - CourseScanner: courseScanner, - Port: *port, - IsProduction: isProduction, + router := api.NewRouter(&api.RouterConfig{ + DbManager: dbManager, + Logger: logger, + AppFs: appFs, + CourseScan: courseScan, + Port: *port, + IsProduction: isProduction, }) var wg sync.WaitGroup @@ -115,7 +117,7 @@ func main() { fmt.Println("\nShutting down...") // Delete all scans - _, err = dbManager.DataDb.Exec("DELETE FROM " + daos.NewScanDao(dbManager.DataDb).Table()) + _, err = dbManager.DataDb.Exec("DELETE FROM " + models.SCAN_TABLE) if err != nil { log.Fatal("Failed to delete scans", err) } @@ -128,8 +130,6 @@ func main() { // loggerBeforeAddFunc is a logger.BeforeAddFn func loggerBeforeAddFn(db database.Database) logger.BeforeAddFn { - logsDao := daos.NewLogDao(db) - return func(ctx context.Context, log *logger.Log) bool { // Skip calls to the logs API if strings.HasPrefix(log.Message, "GET /api/logs") { @@ -138,10 +138,10 @@ func loggerBeforeAddFn(db database.Database) logger.BeforeAddFn { // This should never happen as the logsDb should be nil, but in the event it is not, skip // logging log writes as it will cause an infinite loop - if strings.HasPrefix(log.Message, "INSERT INTO "+logsDao.Table()) || - strings.HasPrefix(log.Message, "SELECT "+logsDao.Table()) || - strings.HasPrefix(log.Message, "UPDATE "+logsDao.Table()) || - strings.HasPrefix(log.Message, "DELETE FROM "+logsDao.Table()) { + if strings.HasPrefix(log.Message, "INSERT INTO "+models.LOG_TABLE) || + strings.HasPrefix(log.Message, "SELECT "+models.LOG_TABLE) || + strings.HasPrefix(log.Message, "UPDATE "+models.LOG_TABLE) || + strings.HasPrefix(log.Message, "DELETE FROM "+models.LOG_TABLE) { return false } @@ -152,10 +152,12 @@ func loggerBeforeAddFn(db database.Database) logger.BeforeAddFn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // loggerWriteFn returns a logger.WriteFn that writes logs to the database -func loggerWriteFn(db database.Database) logger.WriteFn { +func loggerWriteFn(ctx context.Context, db database.Database) logger.WriteFn { return func(ctx context.Context, logs []*logger.Log) error { + logDao := dao.NewDAO(db) + // Write accumulated logs - db.RunInTransaction(func(tx *database.Tx) error { + db.RunInTransaction(ctx, func(txCtx context.Context) error { model := &models.Log{} for _, l := range logs { @@ -166,7 +168,9 @@ func loggerWriteFn(db database.Database) logger.WriteFn { model.CreatedAt = l.Time model.UpdatedAt = model.CreatedAt - if err := daos.NewLogDao(db).Write(model, tx); err != nil { + // Write the log + err := logDao.WriteLog(txCtx, model) + if err != nil { log.Println("Failed to write log", model, err) } } diff --git a/migrations/data/00001_init.sql b/migrations/data/00001_init.sql index 2c4a422..88b4205 100644 --- a/migrations/data/00001_init.sql +++ b/migrations/data/00001_init.sql @@ -46,7 +46,7 @@ CREATE TABLE assets_progress ( id TEXT PRIMARY KEY NOT NULL, asset_id TEXT NOT NULL UNIQUE, course_id TEXT NOT NULL, - video_pos INTEGER NOT NULL DEFAULT -1, + video_pos INTEGER NOT NULL DEFAULT 0, completed BOOLEAN NOT NULL DEFAULT FALSE, completed_at TEXT, created_at TEXT NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), diff --git a/migrations/data/00002_users.sql b/migrations/data/00002_users.sql new file mode 100644 index 0000000..d91a095 --- /dev/null +++ b/migrations/data/00002_users.sql @@ -0,0 +1,51 @@ +-- +goose Up + +--- Parameters table for application settings +CREATE TABLE params ( + id TEXT PRIMARY KEY NOT NULL, + key TEXT UNIQUE NOT NULL, + value TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + updated_at TEXT NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) +); + +-- Insert initial parameter to track admin creation +INSERT INTO params (id, key, value) VALUES ( + 'wV0418r0Rr', + 'hasAdmin', + 'false' +); + +--- User information +CREATE TABLE users ( + id TEXT PRIMARY KEY NOT NULL, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + role TEXT NOT NULL CHECK(role IN ('admin', 'user')), + created_at TEXT NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + updated_at TEXT NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) +); + +---------------------------------------------------- +-- ALTER SCANS TABLE TO ADD STATUS COLUMN DEFAULT -- +---------------------------------------------------- + +-- Create a new temporary table with the default for status +CREATE TABLE scans_temp ( + id TEXT PRIMARY KEY NOT NULL, + course_id TEXT UNIQUE NOT NULL, + status TEXT NOT NULL DEFAULT 'waiting', + created_at TEXT NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + updated_at TEXT NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + FOREIGN KEY (course_id) REFERENCES courses (id) ON DELETE CASCADE +); + +-- Copy data from the original table to the temporary table +INSERT INTO scans_temp (id, course_id, status, created_at, updated_at) +SELECT id, course_id, status, created_at, updated_at FROM scans; + +-- Drop the original table +DROP TABLE scans; + +-- Rename the temporary table to the original name +ALTER TABLE scans_temp RENAME TO scans; \ No newline at end of file diff --git a/models/asset.go b/models/asset.go index 80bca8d..f6e5fb2 100644 --- a/models/asset.go +++ b/models/asset.go @@ -4,17 +4,16 @@ package models import ( "database/sql" - "time" + "github.com/geerew/off-course/utils/schema" "github.com/geerew/off-course/utils/types" ) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Asset defines the model for an asset (table: assets) +// Asset defines the model for an asset type Asset struct { - BaseModel - + Base CourseID string Title string Prefix sql.NullInt16 @@ -23,15 +22,50 @@ type Asset struct { Path string Hash string - // -------------------------------- - // Not in this table, but added via a join - // -------------------------------- + // Relations + Progress AssetProgress + Attachments []*Attachment +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Asset Progress - VideoPos int - Completed bool - CompletedAt time.Time +var ( + ASSET_TABLE = "assets" + ASSET_COURSE_ID = "course_id" + ASSET_TITLE = "title" + ASSET_PREFIX = "prefix" + ASSET_CHAPTER = "chapter" + ASSET_TYPE = "type" + ASSET_PATH = "path" + ASSET_HASH = "hash" + ASSET_VIDEO_POSITION = "video_pos" + ASSET_COMPLETED = "completed" + ASSET_COMPLETED_AT = "completed_at" +) - // Attachments - Attachments []*Attachment +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (a *Asset) Table() string { + return ASSET_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Fields implements the `schema.Modeler` interface by defining the model fields +func (a *Asset) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + // Common fields + s.Field("CourseID").Column(ASSET_COURSE_ID).NotNull() + s.Field("Title").Column(COURSE_TITLE).NotNull().Mutable() + s.Field("Prefix").Column(ASSET_PREFIX).Mutable() + s.Field("Chapter").Column(ASSET_CHAPTER).Mutable() + s.Field("Type").Column(ASSET_TYPE).NotNull().Mutable() + s.Field("Path").Column(ASSET_PATH).NotNull().Mutable() + s.Field("Hash").Column(ASSET_HASH).NotNull().Mutable() + + // Relation fields + s.Relation("Progress").MatchOn(ASSET_PROGRESS_ASSET_ID) + s.Relation("Attachments").MatchOn(ATTACHMENT_ASSET_ID) } diff --git a/models/asset_progress.go b/models/asset_progress.go index 35dcc32..af26dc6 100644 --- a/models/asset_progress.go +++ b/models/asset_progress.go @@ -1,18 +1,50 @@ package models -import "time" +import ( + "github.com/geerew/off-course/utils/schema" + "github.com/geerew/off-course/utils/types" +) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// AssetProgress defines the model for a course progress (table: assets_progress) +// AssetProgress defines the model for an asset progress type AssetProgress struct { - BaseModel - + Base AssetID string CourseID string VideoPos int Completed bool - CompletedAt time.Time + CompletedAt types.DateTime +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + ASSET_PROGRESS_TABLE = "assets_progress" + ASSET_PROGRESS_ASSET_ID = "asset_id" + ASSET_PROGRESS_COURSE_ID = "course_id" + ASSET_PROGRESS_VIDEO_POS = "video_pos" + ASSET_PROGRESS_COMPLETED = "completed" + ASSET_PROGRESS_COMPLETED_AT = "completed_at" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (a *AssetProgress) Table() string { + return ASSET_PROGRESS_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Fields implements the `schema.Modeler` interface by defining the model fields +func (a *AssetProgress) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + // Common fields + s.Field("AssetID").Column(ASSET_PROGRESS_ASSET_ID).NotNull() + s.Field("CourseID").Column(ASSET_PROGRESS_COURSE_ID).NotNull() + s.Field("VideoPos").Column(ASSET_PROGRESS_VIDEO_POS).Mutable() + s.Field("Completed").Column(ASSET_PROGRESS_COMPLETED).Mutable() + s.Field("CompletedAt").Column(ASSET_PROGRESS_COMPLETED_AT).Mutable() } diff --git a/models/attachment.go b/models/attachment.go index bc7c47b..fc35c0f 100644 --- a/models/attachment.go +++ b/models/attachment.go @@ -1,13 +1,43 @@ package models +import "github.com/geerew/off-course/utils/schema" + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Attachment defines the model for an attachment (table: attachments) +// Attachment defines the model for an attachment type Attachment struct { - BaseModel - + Base CourseID string AssetID string Title string Path string } + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + ATTACHMENT_TABLE = "attachments" + ATTACHMENT_COURSE_ID = "course_id" + ATTACHMENT_ASSET_ID = "asset_id" + ATTACHMENT_TITLE = "title" + ATTACHMENT_PATH = "path" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (a *Attachment) Table() string { + return ATTACHMENT_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Fields implements the `schema.Modeler` interface by defining the model fields +func (a *Attachment) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + s.Field("CourseID").Column(ATTACHMENT_COURSE_ID).NotNull() + s.Field("AssetID").Column(ATTACHMENT_ASSET_ID).NotNull() + s.Field("Title").Column(ATTACHMENT_TITLE).NotNull().Mutable() + s.Field("Path").Column(ATTACHMENT_PATH).NotNull().Mutable() +} diff --git a/models/base.go b/models/base.go index 3dbd25f..e153e4a 100644 --- a/models/base.go +++ b/models/base.go @@ -1,44 +1,79 @@ package models import ( - "time" - + "github.com/geerew/off-course/utils/schema" "github.com/geerew/off-course/utils/security" + "github.com/geerew/off-course/utils/types" ) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// BaseModel defines the base model for all models -type BaseModel struct { +type Modeler interface { + Table() string + Id() string + RefreshId() + RefreshCreatedAt() + RefreshUpdatedAt() +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Base defines the base model for all models +type Base struct { ID string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt types.DateTime + UpdatedAt types.DateTime +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + BASE_ID = "id" + BASE_CREATED_AT = "created_at" + BASE_UPDATED_AT = "updated_at" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Define implements the `schema.Modeler` interface by defining the model +func (b *Base) Define(s *schema.ModelConfig) { + // Common fields + s.Field("ID").Column(BASE_ID).NotNull() + s.Field("CreatedAt").Column(BASE_CREATED_AT).NotNull() + s.Field("UpdatedAt").Column(BASE_UPDATED_AT).NotNull().Mutable() +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Id returns the model ID +func (b *Base) Id() string { + return b.ID } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // RefreshId generates and sets a new model ID -func (b *BaseModel) RefreshId() { +func (b *Base) RefreshId() { b.ID = security.PseudorandomString(10) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // SetId sets the model ID -func (b *BaseModel) SetId(id string) { +func (b *Base) SetId(id string) { b.ID = id } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // RefreshCreatedAt updates the Created At field to the current date/time -func (b *BaseModel) RefreshCreatedAt() { - b.CreatedAt = time.Now() +func (b *Base) RefreshCreatedAt() { + b.CreatedAt = types.NowDateTime() } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // RefreshUpdatedAt updates the Updated At field to the current date/time -func (b *BaseModel) RefreshUpdatedAt() { - b.UpdatedAt = time.Now() +func (b *Base) RefreshUpdatedAt() { + b.UpdatedAt = types.NowDateTime() } diff --git a/models/base_test.go b/models/base_test.go index 2b2dd33..56fc17d 100644 --- a/models/base_test.go +++ b/models/base_test.go @@ -9,8 +9,8 @@ import ( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func TestRefreshId(t *testing.T) { - model := BaseModel{} +func Test_RefreshId(t *testing.T) { + model := Base{} model.RefreshId() require.NotEmpty(t, model.ID) @@ -25,16 +25,16 @@ func TestRefreshId(t *testing.T) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func TestSetId(t *testing.T) { - model := BaseModel{} +func Test_SetId(t *testing.T) { + model := Base{} model.SetId("testId") require.Equal(t, "testId", model.ID) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func TestRefreshCreatedAt(t *testing.T) { - model := BaseModel{} +func Test_RefreshCreatedAt(t *testing.T) { + model := Base{} model.RefreshCreatedAt() require.False(t, model.CreatedAt.IsZero()) @@ -50,8 +50,8 @@ func TestRefreshCreatedAt(t *testing.T) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func TestRefreshUpdatedAt(t *testing.T) { - model := BaseModel{} +func Test_RefreshUpdatedAt(t *testing.T) { + model := Base{} model.RefreshUpdatedAt() require.False(t, model.UpdatedAt.IsZero()) diff --git a/models/course.go b/models/course.go index 25c8687..1229472 100644 --- a/models/course.go +++ b/models/course.go @@ -1,30 +1,65 @@ package models -import "time" +import ( + "fmt" -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + "github.com/geerew/off-course/utils/schema" + "github.com/geerew/off-course/utils/types" +) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Course defines the model for a course (table: courses) +// Course defines the model for a course type Course struct { - BaseModel + Base Title string Path string CardPath string Available bool - // -------------------------------- - // Not in this table, but added via join - // -------------------------------- + // Joins + ScanStatus types.ScanStatus + + // Relations + Progress CourseProgress +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + COURSE_TABLE = "courses" + COURSE_TITLE = "title" + COURSE_PATH = "path" + COURSE_CARD_PATH = "card_path" + COURSE_AVAILABLE = "available" + COURSE_SCAN_STATUS = "status" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (c *Course) Table() string { + return COURSE_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Fields implements the `schema.Modeler` interface by defining the model fields +func (c *Course) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + // Common fields + s.Field("Title").Column(COURSE_TITLE).NotNull() + s.Field("Path").Column(COURSE_PATH).NotNull() + s.Field("CardPath").Column(COURSE_CARD_PATH).Mutable() + s.Field("Available").Column(COURSE_AVAILABLE).Mutable() + + // Join fields + s.Field("ScanStatus").JoinTable(SCAN_TABLE).Column(COURSE_SCAN_STATUS).Alias("scan_status") - // Scan status - ScanStatus string + // Relation fields + s.Relation("Progress").MatchOn(COURSE_PROGRESS_COURSE_ID) - // Course Progress - Started bool - StartedAt time.Time - Percent int - CompletedAt time.Time - ProgressUpdatedAt time.Time + // Joins + s.LeftJoin(SCAN_TABLE).On(fmt.Sprintf("%s.%s = %s.%s", COURSE_TABLE, BASE_ID, SCAN_TABLE, SCAN_COURSE_ID)) } diff --git a/models/course_progress.go b/models/course_progress.go index f66a622..3222ff0 100644 --- a/models/course_progress.go +++ b/models/course_progress.go @@ -1,18 +1,50 @@ package models -import "time" +import ( + "github.com/geerew/off-course/utils/schema" + "github.com/geerew/off-course/utils/types" +) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// CourseProgress defines the model for a course progress (table: courses_progress) +// CourseProgress defines the model for a course progress type CourseProgress struct { - BaseModel - + Base CourseID string Started bool - StartedAt time.Time + StartedAt types.DateTime Percent int - CompletedAt time.Time + CompletedAt types.DateTime +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + COURSE_PROGRESS_TABLE = "courses_progress" + COURSE_PROGRESS_COURSE_ID = "course_id" + COURSE_PROGRESS_STARTED = "started" + COURSE_PROGRESS_STARTED_AT = "started_at" + COURSE_PROGRESS_PERCENT = "percent" + COURSE_PROGRESS_COMPLETED_AT = "completed_at" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (cp *CourseProgress) Table() string { + return COURSE_PROGRESS_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Fields implements the `schema.Modeler` interface by defining the model fields +func (cp *CourseProgress) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + // Common fields + s.Field("CourseID").Column(COURSE_PROGRESS_COURSE_ID).NotNull() + s.Field("Started").Column(COURSE_PROGRESS_STARTED).Mutable() + s.Field("StartedAt").Column(COURSE_PROGRESS_STARTED_AT).Mutable() + s.Field("Percent").Column(COURSE_PROGRESS_PERCENT).Mutable() + s.Field("CompletedAt").Column(COURSE_PROGRESS_COMPLETED_AT).Mutable() } diff --git a/models/course_tag.go b/models/course_tag.go index c27742c..e4fe8c0 100644 --- a/models/course_tag.go +++ b/models/course_tag.go @@ -1,18 +1,54 @@ package models -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +import ( + "fmt" -// CourseTag defines the model for a course tag (table: courses_tags) -type CourseTag struct { - BaseModel + "github.com/geerew/off-course/utils/schema" +) - TagId string - CourseId string +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // -------------------------------- - // Not in this table, but added via join - // -------------------------------- +// CourseTag defines the model for a course tag +type CourseTag struct { + Base + TagID string + CourseID string + // Joins Course string Tag string } + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + COURSE_TAG_TABLE = "courses_tags" + COURSE_TAG_TAG_ID = "tag_id" + COURSE_TAG_COURSE_ID = "course_id" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (ct *CourseTag) Table() string { + return COURSE_TAG_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Define implements the `schema.Modeler` interface by defining the model +func (ct *CourseTag) Define(c *schema.ModelConfig) { + c.Embedded("Base") + + // Common fields + c.Field("TagID").Column(COURSE_TAG_TAG_ID).NotNull() + c.Field("CourseID").Column(COURSE_TAG_COURSE_ID).NotNull() + + // Join field + c.Field("Course").JoinTable(COURSE_TABLE).Column(COURSE_TITLE).Alias("course_title") + c.Field("Tag").JoinTable(TAG_TABLE).Column(TAG_TAG).Alias("tag_tag") + + c.LeftJoin(COURSE_TABLE).On(fmt.Sprintf("%s.%s = %s.%s", COURSE_TAG_TABLE, COURSE_TAG_COURSE_ID, COURSE_TABLE, BASE_ID)) + c.LeftJoin(TAG_TABLE).On(fmt.Sprintf("%s.%s = %s.%s", COURSE_TAG_TABLE, COURSE_TAG_TAG_ID, TAG_TABLE, BASE_ID)) + +} diff --git a/models/log.go b/models/log.go index f6a21ad..07af477 100644 --- a/models/log.go +++ b/models/log.go @@ -1,14 +1,44 @@ package models -import "github.com/geerew/off-course/utils/types" +import ( + "github.com/geerew/off-course/utils/schema" + "github.com/geerew/off-course/utils/types" +) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Log defines the model for a log (table: logs) +// Log defines the model for a log type Log struct { - BaseModel - - Data types.JsonMap + Base Level int Message string + Data types.JsonMap +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + LOG_TABLE = "logs" + LOG_LEVEL = "level" + LOG_MESSAGE = "message" + LOG_DATA = "data" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (l *Log) Table() string { + return LOG_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Fields implements the `schema.Modeler` interface by defining the model fields +func (l *Log) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + // Common fields + s.Field("Level").Column(LOG_LEVEL) + s.Field("Message").Column(LOG_MESSAGE).NotNull() + s.Field("Data").Column(LOG_DATA).NotNull() } diff --git a/models/params.go b/models/params.go new file mode 100644 index 0000000..472731a --- /dev/null +++ b/models/params.go @@ -0,0 +1,38 @@ +package models + +import "github.com/geerew/off-course/utils/schema" + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Param defines the model for a parameter +type Param struct { + Base + Key string + Value string +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + PARAM_TABLE = "params" + PARAM_KEY = "key" + PARAM_VALUE = "value" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (p *Param) Table() string { + return PARAM_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Fields implements the `schema.Modeler` interface by defining the model fields +func (p *Param) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + // Common fields + s.Field("Key").Column(PARAM_KEY).NotNull() + s.Field("Value").Column(PARAM_VALUE).NotNull().Mutable() +} diff --git a/models/scan.go b/models/scan.go index 857599b..ac21d14 100644 --- a/models/scan.go +++ b/models/scan.go @@ -3,21 +3,53 @@ package models // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ import ( + "fmt" + + "github.com/geerew/off-course/utils/schema" "github.com/geerew/off-course/utils/types" ) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Scan defines the model for a scan (table: scans) +// Scan defines the model for a scan type Scan struct { - BaseModel - + Base CourseID string Status types.ScanStatus - // -------------------------------- - // Not in this table, but added via join - // -------------------------------- - + // Joins CoursePath string } + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + SCAN_TABLE = "scans" + SCAN_COURSE_ID = "course_id" + SCAN_STATUS = "status" + SCAN_COURSE_PATH = "path" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (s *Scan) Table() string { + return SCAN_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Define implements the `schema.Modeler` interface by defining the model +func (s *Scan) Define(c *schema.ModelConfig) { + c.Embedded("Base") + + // Common fields + c.Field("CourseID").Column(SCAN_COURSE_ID) + c.Field("Status").Column(SCAN_STATUS).Mutable().IgnoreIfNull() + + // Join fields + c.Field("CoursePath").JoinTable(COURSE_TABLE).Column(SCAN_COURSE_PATH).Alias("course_path") + + // Joins + c.LeftJoin(COURSE_TABLE).On(fmt.Sprintf("%s.%s = %s.%s", SCAN_TABLE, SCAN_COURSE_ID, COURSE_TABLE, BASE_ID)) +} diff --git a/models/tag.go b/models/tag.go index 8d98141..c5d9527 100644 --- a/models/tag.go +++ b/models/tag.go @@ -1,18 +1,48 @@ package models +import "github.com/geerew/off-course/utils/schema" + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Tag defines the model for a tag (table: tags) +// Tag defines the model for a tag type Tag struct { - BaseModel - + Base Tag string - // -------------------------------- - // Not in this table, but added via a join - // -------------------------------- + // Joins + // CourseCount int + + // Relations + CourseTags []*CourseTag +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + TAG_TABLE = "tags" + TAG_TAG = "tag" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (t *Tag) Table() string { + return TAG_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Define implements the `schema.Modeler` interface by defining the model +func (t *Tag) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + // Common fields + s.Field("Tag").Column(TAG_TAG).NotNull().Mutable() + + // Relation fields + s.Relation("CourseTags").MatchOn(COURSE_TAG_TAG_ID) + + // c.Field("CourseCount").AggregateColumn("COALESCE(COUNT(courses_tags.id), 0)") + // c.LeftJoin(COURSE_TABLE).On(fmt.Sprintf("%s.%s = %s.%s", SCAN_TABLE, SCAN_COURSE_ID, COURSE_TABLE, BASE_ID)) - // Courses - CourseCount int - CourseTags []*CourseTag } diff --git a/models/user.go b/models/user.go new file mode 100644 index 0000000..e7187af --- /dev/null +++ b/models/user.go @@ -0,0 +1,45 @@ +package models + +import ( + "github.com/geerew/off-course/utils/schema" + "github.com/geerew/off-course/utils/types" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// User defines the model for a user +type User struct { + Base + + Username string + PasswordHash string + Role types.UserRole +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + USER_TABLE = "users" + USER_USERNAME = "username" + USER_PASSWORD_HASH = "password_hash" + USER_ROLE = "role" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `schema.Modeler` interface by returning the table name +func (u *User) Table() string { + return USER_TABLE +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Fields implements the `schema.Modeler` interface by defining the model fields +func (u *User) Define(s *schema.ModelConfig) { + s.Embedded("Base") + + // Common fields + s.Field("Username").Column(USER_USERNAME).NotNull() + s.Field("PasswordHash").Column(USER_PASSWORD_HASH).NotNull() + s.Field("Role").Column(USER_ROLE).NotNull() +} diff --git a/utils/appFs/appFs_test.go b/utils/appFs/appFs_test.go index 944cfad..0a9aad6 100644 --- a/utils/appFs/appFs_test.go +++ b/utils/appFs/appFs_test.go @@ -53,7 +53,7 @@ func Test_Open(t *testing.T) { appFs.Fs.Create("/a") res, err := appFs.Open("/a") - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, res) }) } @@ -94,7 +94,7 @@ func Test_ReadDir(t *testing.T) { appFs.Fs.Mkdir("/c", 0755) res, err := appFs.ReadDir("/", true) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, res) require.Equal(t, 2, len(res.Files)) require.Equal(t, 1, len(res.Directories)) @@ -151,25 +151,25 @@ func Test_ReadDirFlat(t *testing.T) { // Depth 0 (same as 1) res, err := appFs.ReadDirFlat("/", 0) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, res) require.Equal(t, 1, len(res)) // Depth 1 res, err = appFs.ReadDirFlat("/", 1) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, res) require.Equal(t, 1, len(res)) // Depth 10 res, err = appFs.ReadDirFlat("/", 2) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, res) require.Equal(t, 5, len(res)) // Depth 10 res, err = appFs.ReadDirFlat("/", 10) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, res) require.Equal(t, 11, len(res)) }) @@ -185,7 +185,7 @@ func Test_NonWslDrives(t *testing.T) { paths := []string{} for _, p := range paths { err := appFs.Fs.MkdirAll(p, os.ModePerm) - require.Nil(t, err) + require.NoError(t, err) } drives, err := appFs.nonWslDrives() @@ -194,7 +194,7 @@ func Test_NonWslDrives(t *testing.T) { t.Skip("not implemented") } - require.Nil(t, err) + require.NoError(t, err) require.NotEmpty(t, drives) }) } @@ -208,7 +208,7 @@ func Test_WslDrives(t *testing.T) { paths := []string{} for _, p := range paths { err := appFs.Fs.MkdirAll(p, os.ModePerm) - require.Nil(t, err) + require.NoError(t, err) } drives, err := appFs.wslDrives() @@ -227,11 +227,11 @@ func Test_WslDrives(t *testing.T) { paths := []string{"/mnt/c", "/mnt/d", "/mnt/wsl", "/mnt/wslg"} for _, p := range paths { err := appFs.Fs.MkdirAll(p, os.ModePerm) - require.Nil(t, err) + require.NoError(t, err) } drives, err := appFs.wslDrives() - require.Nil(t, err) + require.NoError(t, err) require.Len(t, drives, 3) require.ElementsMatch(t, []string{"/", filepath.Join("/mnt", "c"), filepath.Join("/mnt", "d")}, drives) }) @@ -263,7 +263,7 @@ func Test_PartialHash(t *testing.T) { require.Nil(t, afero.WriteFile(appFs.Fs, "/test/"+tt.name, tt.content, 0644)) hash, err := appFs.PartialHash("/test/"+tt.name, 1024*1024) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, tt.expected, hash) }) } @@ -275,14 +275,14 @@ func Test_PartialHash(t *testing.T) { require.Nil(t, afero.WriteFile(appFs.Fs, "/test/data", []byte("Some test data"), 0644)) hash, err := appFs.PartialHash("/test/data", 1024*1024) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, "0843f7816915fae7fc9c31dbbb3e8745015b53a297930e522d544c13287cb062", hash) // Rename the file require.Nil(t, appFs.Fs.Rename("/test/data", "/test/newdata")) hash, err = appFs.PartialHash("/test/newdata", 1024*1024) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, "0843f7816915fae7fc9c31dbbb3e8745015b53a297930e522d544c13287cb062", hash) }) } diff --git a/utils/coursescan/base.go b/utils/coursescan/base.go new file mode 100644 index 0000000..7c51ab3 --- /dev/null +++ b/utils/coursescan/base.go @@ -0,0 +1,681 @@ +package coursescan + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "os" + "path/filepath" + "regexp" + "strconv" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/dao" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/appFs" + "github.com/geerew/off-course/utils/security" + "github.com/geerew/off-course/utils/types" +) + +var ( + loggerType = slog.Any("type", types.LogTypeScanner) +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CourseScanProcessorFn is a function that processes a course scan job +type CourseScanProcessorFn func(context.Context, *CourseScan, *models.Scan) error + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CourseScan scans a course and finds assets and attachments +type CourseScan struct { + appFs *appFs.AppFs + db database.Database + dao *dao.DAO + logger *slog.Logger + jobSignal chan bool +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CourseScanConfig is the config for a CourseScan +type CourseScanConfig struct { + Db database.Database + AppFs *appFs.AppFs + Logger *slog.Logger +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// NewCourseScan creates a new CourseScan +func NewCourseScan(config *CourseScanConfig) *CourseScan { + return &CourseScan{ + appFs: config.AppFs, + db: config.Db, + dao: dao.NewDAO(config.Db), + logger: config.Logger, + jobSignal: make(chan bool, 1), + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Add inserts a course scan job into the db +func (s *CourseScan) Add(ctx context.Context, courseId string) (*models.Scan, error) { + // Check if the course exists + course := &models.Course{Base: models.Base{ID: courseId}} + err := s.dao.GetById(ctx, course) + if err != nil { + if err == sql.ErrNoRows { + return nil, utils.ErrInvalidId + } + + return nil, err + } + + // Do nothing when a scan job is already in progress + if course.ScanStatus.IsWaiting() || course.ScanStatus.IsProcessing() { + s.logger.Debug( + "Scan already in progress", + loggerType, + slog.String("path", course.Path), + ) + + // Get the scan from the db and return that + scan := &models.Scan{} + err := s.dao.Get(ctx, scan, &database.Options{Where: squirrel.Eq{scan.Table() + ".course_id": courseId}}) + if err != nil { + return nil, err + } + + return scan, nil + } + + // Add the job + scan := &models.Scan{CourseID: courseId, Status: types.NewScanStatusWaiting()} + if err := s.dao.CreateScan(ctx, scan); err != nil { + return nil, err + } + + // Signal the worker to process the job + select { + case s.jobSignal <- true: + default: + } + + s.logger.Info( + "Added scan job", + loggerType, + slog.String("path", course.Path), + ) + + return scan, nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Worker processes jobs out of the DB sequentially +func (s *CourseScan) Worker(ctx context.Context, processorFn CourseScanProcessorFn, processingDone chan bool) { + s.logger.Debug("Started course scanner worker", loggerType) + + for { + <-s.jobSignal + + // Keep process jobs from the scans table until there are no more jobs + for { + nextScan := &models.Scan{} + err := s.dao.NextWaitingScan(ctx, nextScan) + if err != nil { + // Nothing more to process + if err == sql.ErrNoRows { + s.logger.Debug("Finished processing all scan jobs", loggerType) + break + } + + // Error + s.logger.Error( + "Failed to look up the next scan job", + loggerType, + slog.String("error", err.Error()), + ) + + break + } + + s.logger.Info( + "Processing scan job", + loggerType, + slog.String("job", nextScan.ID), + slog.String("path", nextScan.CoursePath), + ) + + err = processorFn(ctx, s, nextScan) + if err != nil { + s.logger.Error( + "Failed to process scan job", + loggerType, + slog.String("error", err.Error()), + slog.String("path", nextScan.CoursePath), + ) + } + + // Cleanup + if err := s.dao.Delete(ctx, nextScan, nil); err != nil { + s.logger.Error( + "Failed to delete scan job", + loggerType, + slog.String("error", err.Error()), + slog.String("job", nextScan.ID), + ) + + break + } + } + + // Signal that processing is done + if processingDone != nil { + processingDone <- true + } + + // Clear any pending signal that were sent while processing + select { + case <-s.jobSignal: + default: + } + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +type assetMap map[string]map[int]*models.Asset +type attachmentMap map[string]map[int][]*models.Attachment + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Processor scans a course to identify assets and attachments +func Processor(ctx context.Context, s *CourseScan, scan *models.Scan) error { + if scan == nil { + return ErrNilScan + } + + // Set the scan status to processing + scan.Status.SetProcessing() + err := s.dao.UpdateScan(ctx, scan) + if err != nil { + return err + } + + // Get the course for this scan + course := &models.Course{} + err = s.dao.Get(ctx, course, &database.Options{Where: squirrel.Eq{course.Table() + ".id": scan.CourseID}}) + if err != nil { + if err == sql.ErrNoRows { + s.logger.Debug( + "Ignoring scan job as the course no longer exists", + loggerType, + slog.String("path", scan.CoursePath), + ) + + return nil + } + + return err + } + + // Check availability and skip when unavailable. Also marks the course as unavailable + _, err = s.appFs.Fs.Stat(course.Path) + if err != nil { + if os.IsNotExist(err) { + s.logger.Debug( + "Skipping as the course is unavailable", + loggerType, + slog.String("path", scan.CoursePath), + ) + + if course.Available { + course.Available = false + err = s.dao.UpdateCourse(ctx, course) + if err != nil { + return err + } + } + + return nil + } + + return err + } + + // If the course is currently marked as unavailable, set it as available + if !course.Available { + course.Available = true + err := s.dao.UpdateCourse(ctx, course) + if err != nil { + return err + } + + s.logger.Debug( + "Setting unavailable course as available", + loggerType, + slog.String("path", scan.CoursePath), + ) + } + + cardPath := "" + + // Get all files down to a depth of 2 + files, err := s.appFs.ReadDirFlat(course.Path, 2) + if err != nil { + return err + } + + // Maps to hold assets and attachments by [chapter][prefix] + assetsMap := assetMap{} + attachmentsMap := attachmentMap{} + + for _, fp := range files { + normalizedPath := utils.NormalizeWindowsDrive(fp) + filename := filepath.Base(normalizedPath) + fileDir := filepath.Dir(normalizedPath) + isInRoot := fileDir == utils.NormalizeWindowsDrive(course.Path) + + // Check if this file is the course card + if isInRoot && isCard(filename) { + if cardPath != "" { + s.logger.Debug( + "Found another course card. Ignoring", + loggerType, + slog.String("file", filename), + slog.String("path", scan.CoursePath), + ) + } else { + cardPath = normalizedPath + } + + continue + } + + // Set the chapter. This will be empty when the file is in the root directory + chapter := "" + if !isInRoot { + chapter = filepath.Base(fileDir) + } + + if _, exists := assetsMap[chapter]; !exists { + assetsMap[chapter] = make(map[int]*models.Asset) + } + + if _, exists := attachmentsMap[chapter]; !exists { + attachmentsMap[chapter] = make(map[int][]*models.Attachment) + } + + pfn := parseFilename(filename) + + // Ignore files that are neither assets nor attachments + if pfn == nil { + s.logger.Debug( + "Incompatible file name. Ignoring", + loggerType, + slog.String("path", scan.CoursePath), + slog.String("file", normalizedPath), + ) + + continue + } + + // Add attachment + if pfn.asset == nil { + attachmentsMap[chapter][pfn.prefix] = append( + attachmentsMap[chapter][pfn.prefix], + &models.Attachment{ + Title: pfn.title, + Path: normalizedPath, + CourseID: course.ID, + }, + ) + + continue + } + + newAsset := &models.Asset{ + Title: pfn.title, + Prefix: sql.NullInt16{Int16: int16(pfn.prefix), Valid: true}, + CourseID: course.ID, + Chapter: chapter, + Path: normalizedPath, + Type: *pfn.asset, + } + + existing, exists := assetsMap[chapter][pfn.prefix] + + if !exists { + // Add asset + hash, err := s.appFs.PartialHash(normalizedPath, 1024*1024) + if err != nil { + return err + } + newAsset.Hash = hash + + assetsMap[chapter][pfn.prefix] = newAsset + + } else { + // Check if this new asset has a higher priority than the existing asset. The priority + // is video > html > pdf + if newAsset.Type.IsVideo() && !existing.Type.IsVideo() || + newAsset.Type.IsHTML() && existing.Type.IsPDF() { + + // Demote the existing asset to an attachment and add the new asset + s.logger.Debug( + "Found a higher priority asset. Replacing", + loggerType, + slog.String("path", scan.CoursePath), + slog.String("file", normalizedPath), + ) + + hash, err := s.appFs.PartialHash(normalizedPath, 1024*1024) + if err != nil { + return err + } + newAsset.Hash = hash + + assetsMap[chapter][pfn.prefix] = newAsset + + attachmentsMap[chapter][pfn.prefix] = append( + attachmentsMap[chapter][pfn.prefix], + &models.Attachment{ + Title: existing.Title + filepath.Ext(existing.Path), + Path: existing.Path, + CourseID: course.ID, + }, + ) + } else { + // Add the new asset as an attachment + attachmentsMap[chapter][pfn.prefix] = append( + attachmentsMap[chapter][pfn.prefix], + &models.Attachment{ + Title: pfn.title, + Path: normalizedPath, + CourseID: course.ID, + }, + ) + } + } + } + + course.CardPath = cardPath + + return s.db.RunInTransaction(ctx, func(txCtx context.Context) error { + // Convert the assets map to a slice + assets := make([]*models.Asset, 0, len(files)) + for _, chapterMap := range assetsMap { + for _, asset := range chapterMap { + assets = append(assets, asset) + } + } + + fmt.Println("assets", assets) + + // Update the assets in DB + if len(assets) > 0 { + err = updateAssets(txCtx, s.dao, course.ID, assets) + if err != nil { + return err + } + } + + // Convert the attachments map to a slice + attachments := []*models.Attachment{} + for chapter, attachmentMap := range attachmentsMap { + for prefix, potentialAttachments := range attachmentMap { + // Only add attachments when there is an assert + if asset, exists := assetsMap[chapter][prefix]; exists { + for _, attachment := range potentialAttachments { + attachment.AssetID = asset.ID + attachments = append(attachments, attachment) + } + } + } + } + + // Update the attachments in DB + if len(attachments) > 0 { + err = updateAttachments(txCtx, s.dao, course.ID, attachments) + if err != nil { + return err + } + } + + err = s.dao.UpdateCourse(txCtx, course) + if err != nil { + return err + } + + return nil + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// PRIVATE +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// parsedFilename that holds information following a filename being parsed +type parsedFilename struct { + prefix int + title string + asset *types.Asset +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// A regex for parsing a file name into a prefix, title, and extension +// +// Valid patterns: +// +// `` +// `.` +// ` ` +// `<prefix>-<title>` +// `<prefix> - <title>` +// `<prefix> <title>.<ext>` +// `<prefix>-<title>.<ext>` +// `<prefix> - <title>.<ext>` +// +// - <prefix> is required and must be a number +// - A dash (-) is optional +// - <title> is optional and can be any non-empty string +// - <ext> is optional +var filenameRegex = regexp.MustCompile(`^\s*(?P<Prefix>[0-9]+)((?:\s+-+\s+|\s+-+|\s+|-+\s*)(?P<Title>[^.][^.]*)?)?(?:\.(?P<Ext>\w+))?$`) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// parseFilename parses a file name and determines if it represents an asset, attachment, or neither +// +// A file is an asset when it matches `<prefix> <title>.<ext>` and <ext> is a valid `types.AssetType` +// +// A file is an attachment when it has a <prefix>, and optionally a <title> and/or <ext>, whereby <ext> +// is not a valid `types.AssetType` +// +// When a file is neither an asset nor an attachment, nil is returned +func parseFilename(filename string) *parsedFilename { + pfn := &parsedFilename{} + + matches := filenameRegex.FindStringSubmatch(filename) + if len(matches) == 0 { + return nil + } + + prefix, err := strconv.Atoi(matches[filenameRegex.SubexpIndex("Prefix")]) + if err != nil { + return nil + } + + pfn.prefix = prefix + pfn.title = matches[filenameRegex.SubexpIndex("Title")] + + // When title is empty, consider this an attachment + if pfn.title == "" { + pfn.title = filename + return pfn + } + + // Where there is no extension, consider this an attachment + ext := matches[filenameRegex.SubexpIndex("Ext")] + if ext == "" { + return pfn + } + + pfn.asset = types.NewAsset(ext) + + // When the extension is not supported, consider this an attachment + if pfn.asset == nil { + pfn.title = pfn.title + "." + ext + } + + return pfn +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// isCard determines if a given file name represents a card based on its name and extension +func isCard(filename string) bool { + // Get the extension. If there is no extension, return false + ext := filepath.Ext(filename) + if ext == "" { + return false + } + + fileWithoutExt := filename[:len(filename)-len(ext)] + if fileWithoutExt != "card" { + return false + } + + // Check if the extension is supported + switch ext[1:] { + case + "jpg", + "jpeg", + "png", + "webp", + "tiff": + return true + } + + return false +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// updateAssets updates the assets in the database based on the assets found on disk. It compares +// the existing assets in the database with the assets found on disk, and performs the necessary +// additions, deletions, and updates +func updateAssets(ctx context.Context, dao *dao.DAO, courseId string, assets []*models.Asset) error { + existingAssets := []*models.Asset{} + err := dao.List(ctx, &existingAssets, &database.Options{Where: squirrel.Eq{models.ASSET_TABLE + ".course_id": courseId}}) + if err != nil { + return err + } + + // Compare the assets found on disk to assets found in DB and identify which assets to add and + // which assets to delete + toAdd, toDelete, err := utils.DiffSliceOfStructsByKey(assets, existingAssets, "Hash") + if err != nil { + return err + } + + // Add assets + // TODO: This could be optimized by using a bulk insert + for _, asset := range toAdd { + if err := dao.CreateAsset(ctx, asset); err != nil { + return err + } + } + + // Delete assets + for _, deleteAsset := range toDelete { + err := dao.Delete(ctx, deleteAsset, nil) + if err != nil { + return err + } + } + + // Identify the existing assets whose information has changed + existingAssetsMap := make(map[string]*models.Asset) + for _, existingAsset := range existingAssets { + existingAssetsMap[existingAsset.Hash] = existingAsset + } + + randomTempSuffix := security.RandomString(10) + updatedAssets := make([]*models.Asset, 0, len(assets)) + + // On the first pass we update the existing assets with details of the new asset. In addition, we + // set the path to be path+randomTempSuffix. This is to prevent a `unique path constraint` error if, + // for example, 2 files are have their titles swapped. + // + // On the second pass we update the existing assets and remove the randomTempSuffix from the path + for _, asset := range assets { + if existingAsset, exists := existingAssetsMap[asset.Hash]; exists { + asset.ID = existingAsset.ID + + if !utils.CompareStructs(asset, existingAsset, []string{"CreatedAt", "UpdatedAt"}) { + asset.Path = asset.Path + randomTempSuffix + updatedAssets = append(updatedAssets, asset) + + // The assets has been updated to have the existing assets ID, so this will update the + // existing asset with the details of the new asset + if err := dao.UpdateAsset(ctx, asset); err != nil { + return err + } + } + } + } + + for _, asset := range updatedAssets { + asset.Path = asset.Path[:len(asset.Path)-len(randomTempSuffix)] + + if err := dao.UpdateAsset(ctx, asset); err != nil { + return err + } + } + + return nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// updateAttachments updates the attachments in the database based on the attachments found on disk. +// It compares the existing attachments in the database with the attachments found on disk, and performs +// the necessary additions and deletions +func updateAttachments(ctx context.Context, dao *dao.DAO, courseId string, attachments []*models.Attachment) error { + existingAttachments := []*models.Attachment{} + err := dao.List(ctx, &existingAttachments, &database.Options{Where: squirrel.Eq{models.ATTACHMENT_TABLE + ".course_id": courseId}}) + if err != nil { + return err + } + + // Compare the attachments found on disk to attachments found in DB + toAdd, toDelete, err := utils.DiffSliceOfStructsByKey(attachments, existingAttachments, "Path") + if err != nil { + return err + } + + // Add attachments + for _, attachment := range toAdd { + if err := dao.CreateAttachment(ctx, attachment); err != nil { + return err + } + } + + // Delete attachments + for _, attachment := range toDelete { + err := dao.Delete(ctx, attachment, nil) + if err != nil { + return err + } + } + + return nil +} diff --git a/utils/coursescan/base_test.go b/utils/coursescan/base_test.go new file mode 100644 index 0000000..620d000 --- /dev/null +++ b/utils/coursescan/base_test.go @@ -0,0 +1,1106 @@ +package coursescan + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/models" + "github.com/geerew/off-course/utils" + "github.com/geerew/off-course/utils/appFs" + "github.com/geerew/off-course/utils/logger" + "github.com/geerew/off-course/utils/security" + "github.com/geerew/off-course/utils/types" + "github.com/spf13/afero" + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func setup(t *testing.T) (*CourseScan, context.Context, *[]*logger.Log) { + t.Helper() + + // Logger + var logs []*logger.Log + var logsMux sync.Mutex + logger, _, err := logger.InitLogger(&logger.BatchOptions{ + BatchSize: 1, + WriteFn: logger.TestWriteFn(&logs, &logsMux), + }) + require.NoError(t, err, "Failed to initialize logger") + + appFs := appFs.NewAppFs(afero.NewMemMapFs(), logger) + + dbManager, err := database.NewSqliteDBManager(&database.DatabaseConfig{ + IsDebug: false, + DataDir: "./oc_data", + AppFs: appFs, + InMemory: true, + }) + + require.NoError(t, err) + require.NotNil(t, dbManager) + + courseScan := NewCourseScan(&CourseScanConfig{ + Db: dbManager.DataDb, + AppFs: appFs, + Logger: logger, + }) + + return courseScan, context.Background(), &logs +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestScanner_Add(t *testing.T) { + t.Run("success", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course1 := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course1)) + + scan1, err := scanner.Add(ctx, course1.ID) + require.NoError(t, err) + require.Equal(t, course1.ID, scan1.CourseID) + + course2 := &models.Course{Title: "Course 2", Path: "/course-2"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course2)) + + scan2, err := scanner.Add(ctx, course2.ID) + require.NoError(t, err) + require.Equal(t, course2.ID, scan2.CourseID) + }) + + t.Run("duplicate", func(t *testing.T) { + scanner, ctx, logs := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + first, err := scanner.Add(ctx, course.ID) + require.NoError(t, err) + require.Equal(t, course.ID, first.CourseID) + + // Add again + second, err := scanner.Add(ctx, course.ID) + require.NoError(t, err) + require.Equal(t, second.ID, first.ID) + require.NotEmpty(t, *logs) + require.Equal(t, "Scan already in progress", (*logs)[len(*logs)-1].Message) + require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) + }) + + t.Run("invalid course", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + scan, err := scanner.Add(ctx, "1234") + require.ErrorIs(t, err, utils.ErrInvalidId) + require.Nil(t, scan) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestScanner_Worker(t *testing.T) { + t.Run("jobs", func(t *testing.T) { + scanner, ctx, logs := setup(t) + + courses := []*models.Course{} + for i := range 3 { + course := &models.Course{Title: fmt.Sprintf("course %d", i), Path: fmt.Sprintf("/course-%d", i)} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + courses = append(courses, course) + } + + var processingDone = make(chan bool, 1) + go scanner.Worker(ctx, func(context.Context, *CourseScan, *models.Scan) error { + time.Sleep(1 * time.Millisecond) + return nil + }, processingDone) + + // Add the courses + for i := range 3 { + scan, err := scanner.Add(ctx, courses[i].ID) + require.NoError(t, err) + require.Equal(t, scan.CourseID, courses[i].ID) + } + + // Wait for the worker to finish + <-processingDone + + // Sometimes the delete is slow to happen + time.Sleep(20 * time.Millisecond) + + count, err := scanner.dao.Count(ctx, &models.Scan{}, nil) + require.NoError(t, err) + require.Zero(t, count) + + require.NotEmpty(t, *logs) + require.Equal(t, "Finished processing all scan jobs", (*logs)[len(*logs)-1].Message) + require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) + + // Add the first 2 courses (again) + for i := range 2 { + scan, err := scanner.Add(ctx, courses[i].ID) + require.NoError(t, err) + require.Equal(t, scan.CourseID, courses[i].ID) + } + + // Wait for the worker to finish + <-processingDone + + count, err = scanner.dao.Count(ctx, &models.Scan{}, nil) + require.NoError(t, err) + require.Zero(t, count) + + require.NotEmpty(t, *logs) + require.Equal(t, "Finished processing all scan jobs", (*logs)[len(*logs)-1].Message) + require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) + }) + + t.Run("error processing", func(t *testing.T) { + scanner, ctx, logs := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + var processingDone = make(chan bool, 1) + go scanner.Worker(ctx, func(context.Context, *CourseScan, *models.Scan) error { + time.Sleep(1 * time.Millisecond) + return errors.New("processing error") + }, processingDone) + + scan, err := scanner.Add(ctx, course.ID) + require.NoError(t, err) + require.Equal(t, scan.CourseID, course.ID) + + // Wait for the worker to finish + <-processingDone + + require.NotEmpty(t, *logs) + require.Greater(t, len(*logs), 2) + require.Equal(t, "Failed to process scan job", (*logs)[len(*logs)-2].Message) + require.Equal(t, slog.LevelError, (*logs)[len(*logs)-2].Level) + require.Equal(t, "Finished processing all scan jobs", (*logs)[len(*logs)-1].Message) + require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestScanner_Processor(t *testing.T) { + t.Run("scan nil", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + err := Processor(ctx, scanner, nil) + require.ErrorIs(t, err, ErrNilScan) + }) + + t.Run("error getting course", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID, Status: types.NewScanStatusWaiting()} + require.NoError(t, scanner.dao.CreateScan(ctx, scan)) + + // Drop the table + _, err := scanner.db.Exec("DROP TABLE IF EXISTS " + course.Table()) + require.NoError(t, err) + + err = Processor(ctx, scanner, scan) + require.ErrorContains(t, err, fmt.Sprintf("no such table: %s", course.Table())) + }) + + t.Run("course unavailable", func(t *testing.T) { + scanner, ctx, logs := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1", Available: true} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID, Status: types.NewScanStatusWaiting()} + require.NoError(t, scanner.dao.CreateScan(ctx, scan)) + + err := Processor(ctx, scanner, scan) + require.NoError(t, err) + + require.NotEmpty(t, *logs) + require.Equal(t, "Skipping as the course is unavailable", (*logs)[len(*logs)-1].Message) + require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) + }) + + t.Run("mark course available", func(t *testing.T) { + scanner, ctx, logs := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1", Available: false} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID, Status: types.NewScanStatusWaiting()} + require.NoError(t, scanner.dao.CreateScan(ctx, scan)) + + scanner.appFs.Fs.Mkdir(course.Path, os.ModePerm) + + err := Processor(ctx, scanner, scan) + require.NoError(t, err) + + require.NotEmpty(t, *logs) + require.Equal(t, "Setting unavailable course as available", (*logs)[len(*logs)-1].Message) + require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) + }) + + t.Run("card", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID, Status: types.NewScanStatusWaiting()} + require.NoError(t, scanner.dao.CreateScan(ctx, scan)) + + scanner.appFs.Fs.Mkdir(course.Path, os.ModePerm) + scanner.appFs.Fs.Create(filepath.Join(course.Path, "card.jpg")) + + err := Processor(ctx, scanner, scan) + require.NoError(t, err) + + courseResult := &models.Course{Base: models.Base{ID: course.ID}} + err = scanner.dao.GetById(ctx, courseResult) + require.NoError(t, err) + require.Equal(t, filepath.Join(course.Path, "card.jpg"), courseResult.CardPath) + + // Ignore card in chapter + scanner.appFs.Fs.Remove(filepath.Join(course.Path, "card.jpg")) + scanner.appFs.Fs.Create(filepath.Join(course.Path, "01 Chapter 1", "card.jpg")) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + courseResult = &models.Course{Base: models.Base{ID: course.ID}} + err = scanner.dao.GetById(ctx, courseResult) + require.NoError(t, err) + require.Empty(t, courseResult.CardPath) + + // Ignore additional cards at the root + scanner.appFs.Fs.Create(filepath.Join(course.Path, "card.jpg")) + scanner.appFs.Fs.Create(filepath.Join(course.Path, "card.png")) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + courseResult = &models.Course{Base: models.Base{ID: course.ID}} + err = scanner.dao.GetById(ctx, courseResult) + require.NoError(t, err) + require.Equal(t, filepath.Join(course.Path, "card.jpg"), courseResult.CardPath) + }) + + t.Run("ignore files", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID, Status: types.NewScanStatusWaiting()} + require.NoError(t, scanner.dao.CreateScan(ctx, scan)) + + scanner.appFs.Fs.Mkdir(course.Path, os.ModePerm) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/file 1", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/file.file", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/file.avi", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/ - file.avi", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/- - file.avi", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/.avi", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/-1 - file.avi", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/a - file.avi", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/1.1 - file.avi", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/2.3-file.avi", course.Path)) + scanner.appFs.Fs.Create(fmt.Sprintf("%s/1file.avi", course.Path)) + + err := Processor(ctx, scanner, scan) + require.NoError(t, err) + + count, err := scanner.dao.Count(ctx, &models.Asset{}, nil) + require.NoError(t, err) + require.Zero(t, count) + }) + + t.Run("assets", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID, Status: types.NewScanStatusWaiting()} + require.NoError(t, scanner.dao.CreateScan(ctx, scan)) + + // Add 2 assets + 1 to ignore + scanner.appFs.Fs.Mkdir(course.Path, os.ModePerm) + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 file 1.mkv", course.Path), []byte("file 1"), os.ModePerm) + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/02 file 2.html", course.Path), []byte("file 2"), os.ModePerm) + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/should ignore", course.Path), []byte("ignore"), os.ModePerm) + + err := Processor(ctx, scanner, scan) + require.NoError(t, err) + + options := &database.Options{ + OrderBy: []string{models.ASSET_TABLE + ".chapter asc", models.ASSET_TABLE + ".prefix asc"}, + Where: squirrel.Eq{models.ASSET_TABLE + ".course_id": course.ID}, + } + + assets := []*models.Asset{} + err = scanner.dao.List(ctx, &assets, options) + require.NoError(t, err) + require.Len(t, assets, 2) + + require.Equal(t, "file 1", assets[0].Title) + require.Equal(t, course.ID, assets[0].CourseID) + require.Equal(t, 1, int(assets[0].Prefix.Int16)) + require.Empty(t, assets[0].Chapter) + require.True(t, assets[0].Type.IsVideo()) + require.Equal(t, "ca934260de4b6eb696e4e9912447bc7f2bd7b614da6879b7addef8e03dca71d1", assets[0].Hash) + + require.Equal(t, "file 2", assets[1].Title) + require.Equal(t, course.ID, assets[1].CourseID) + require.Equal(t, 2, int(assets[1].Prefix.Int16)) + require.Empty(t, assets[1].Chapter) + require.True(t, assets[1].Type.IsHTML()) + require.Equal(t, "21b5bfe70ae6b203182d12bdde12f6f086000e37c894187a47b664ea7ec2331a", assets[1].Hash) + + // Delete asset 1 + scanner.appFs.Fs.Remove(fmt.Sprintf("%s/01 file 1.mkv", course.Path)) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + err = scanner.dao.List(ctx, &assets, options) + require.NoError(t, err) + require.Len(t, assets, 1) + + require.Equal(t, "file 2", assets[0].Title) + require.Equal(t, course.ID, assets[0].CourseID) + require.Equal(t, 2, int(assets[0].Prefix.Int16)) + require.Empty(t, assets[0].Chapter) + require.True(t, assets[0].Type.IsHTML()) + require.Equal(t, "21b5bfe70ae6b203182d12bdde12f6f086000e37c894187a47b664ea7ec2331a", assets[0].Hash) + + // Add asset in chapter + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 Chapter 1/01 file 3.pdf", course.Path), []byte("file 3"), os.ModePerm) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + err = scanner.dao.List(ctx, &assets, options) + require.NoError(t, err) + require.Len(t, assets, 2) + + require.Equal(t, "file 2", assets[0].Title) + require.Equal(t, course.ID, assets[0].CourseID) + require.Equal(t, 2, int(assets[0].Prefix.Int16)) + require.Empty(t, assets[0].Chapter) + require.True(t, assets[0].Type.IsHTML()) + require.Equal(t, "21b5bfe70ae6b203182d12bdde12f6f086000e37c894187a47b664ea7ec2331a", assets[0].Hash) + + require.Equal(t, "file 3", assets[1].Title) + require.Equal(t, course.ID, assets[1].CourseID) + require.Equal(t, 1, int(assets[1].Prefix.Int16)) + require.Equal(t, "01 Chapter 1", assets[1].Chapter) + require.True(t, assets[1].Type.IsPDF()) + require.Equal(t, "333940e348f410361b399939d5e120c72896843ad2bea2e5a961cba6818a9ad9", assets[1].Hash) + }) + + t.Run("attachments", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID, Status: types.NewScanStatusWaiting()} + require.NoError(t, scanner.dao.CreateScan(ctx, scan)) + + // Add asset + scanner.appFs.Fs.Mkdir(course.Path, os.ModePerm) + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 file 1.mkv", course.Path), []byte("file 1"), os.ModePerm) + + err := Processor(ctx, scanner, scan) + require.NoError(t, err) + + assetOptions := &database.Options{ + OrderBy: []string{models.ASSET_TABLE + ".chapter asc", models.ASSET_TABLE + ".prefix asc"}, + Where: squirrel.Eq{models.ASSET_TABLE + ".course_id": course.ID}, + } + + assets := []*models.Asset{} + err = scanner.dao.List(ctx, &assets, assetOptions) + require.NoError(t, err) + require.Len(t, assets, 1) + + require.Equal(t, "file 1", assets[0].Title) + require.Equal(t, course.ID, assets[0].CourseID) + require.Equal(t, 1, int(assets[0].Prefix.Int16)) + require.Empty(t, assets[0].Chapter) + require.True(t, assets[0].Type.IsVideo()) + require.Equal(t, "ca934260de4b6eb696e4e9912447bc7f2bd7b614da6879b7addef8e03dca71d1", assets[0].Hash) + + // Add attachment + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 file 1.txt", course.Path), []byte("file 1"), os.ModePerm) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + attachmentOptions := &database.Options{ + OrderBy: []string{models.ATTACHMENT_TABLE + ".created_at asc"}, + Where: squirrel.Eq{models.ATTACHMENT_TABLE + ".course_id": course.ID}, + } + + attachments := []*models.Attachment{} + err = scanner.dao.List(ctx, &attachments, attachmentOptions) + require.NoError(t, err) + require.Len(t, attachments, 1) + + require.Equal(t, "file 1.txt", attachments[0].Title) + require.Equal(t, course.ID, attachments[0].CourseID) + require.Equal(t, filepath.Join(course.Path, "01 file 1.txt"), attachments[0].Path) + + // Add another attachment + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 file 2.txt", course.Path), []byte("file 2"), os.ModePerm) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + err = scanner.dao.List(ctx, &attachments, attachmentOptions) + require.NoError(t, err) + require.Len(t, attachments, 2) + + require.Equal(t, "file 1.txt", attachments[0].Title) + require.Equal(t, course.ID, attachments[0].CourseID) + require.Equal(t, filepath.Join(course.Path, "01 file 1.txt"), attachments[0].Path) + + require.Equal(t, "file 2.txt", attachments[1].Title) + require.Equal(t, course.ID, attachments[1].CourseID) + require.Equal(t, filepath.Join(course.Path, "01 file 2.txt"), attachments[1].Path) + + // Delete attachment + scanner.appFs.Fs.Remove(fmt.Sprintf("%s/01 file 1.txt", course.Path)) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + err = scanner.dao.List(ctx, &attachments, attachmentOptions) + require.NoError(t, err) + require.Len(t, attachments, 1) + + require.Equal(t, "file 2.txt", attachments[0].Title) + require.Equal(t, course.ID, attachments[0].CourseID) + require.Equal(t, filepath.Join(course.Path, "01 file 2.txt"), attachments[0].Path) + }) + + t.Run("asset priority", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + // ---------------------------- + // Priority is VIDEO -> HTML -> PDF + // ---------------------------- + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + scan := &models.Scan{CourseID: course.ID, Status: types.NewScanStatusWaiting()} + require.NoError(t, scanner.dao.CreateScan(ctx, scan)) + + // Add PDF asset + scanner.appFs.Fs.Mkdir(course.Path, os.ModePerm) + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 doc 1.pdf", course.Path), []byte("doc 1"), os.ModePerm) + + err := Processor(ctx, scanner, scan) + require.NoError(t, err) + + assetOptions := &database.Options{ + OrderBy: []string{models.ASSET_TABLE + ".chapter asc", models.ASSET_TABLE + ".prefix asc"}, + Where: squirrel.Eq{models.ASSET_TABLE + ".course_id": course.ID}, + } + + assets := []*models.Asset{} + err = scanner.dao.List(ctx, &assets, assetOptions) + require.NoError(t, err) + require.Len(t, assets, 1) + + require.Equal(t, "doc 1", assets[0].Title) + require.Equal(t, course.ID, assets[0].CourseID) + require.Equal(t, 1, int(assets[0].Prefix.Int16)) + require.Empty(t, assets[0].Chapter) + require.True(t, assets[0].Type.IsPDF()) + require.Equal(t, "61363a1cb5bf5514e3f9e983b6a96aeb12dd1ccff1b19938231d6b798d5832f9", assets[0].Hash) + require.Len(t, assets[0].Attachments, 0) + + // Add HTML asset + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 index.html", course.Path), []byte("index"), os.ModePerm) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + err = scanner.dao.List(ctx, &assets, assetOptions) + require.NoError(t, err) + require.Len(t, assets, 1) + + require.Equal(t, "index", assets[0].Title) + require.Equal(t, course.ID, assets[0].CourseID) + require.Equal(t, 1, int(assets[0].Prefix.Int16)) + require.Empty(t, assets[0].Chapter) + require.True(t, assets[0].Type.IsHTML()) + require.Equal(t, "935600bd3714c889e3d03a3196cf0e90b4a6aa51af8a73f7867c8a421a1106ba", assets[0].Hash) + require.Len(t, assets[0].Attachments, 1) + + attachmentOptions := &database.Options{ + OrderBy: []string{models.ATTACHMENT_TABLE + ".created_at asc"}, + Where: squirrel.Eq{models.ATTACHMENT_TABLE + ".course_id": course.ID}, + } + + attachments := []*models.Attachment{} + err = scanner.dao.List(ctx, &attachments, attachmentOptions) + require.NoError(t, err) + require.Len(t, attachments, 1) + require.Equal(t, filepath.Join(course.Path, "01 doc 1.pdf"), attachments[0].Path) + + // Add VIDEO asset + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 video.mp4", course.Path), []byte("video"), os.ModePerm) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + err = scanner.dao.List(ctx, &assets, assetOptions) + require.NoError(t, err) + require.Len(t, assets, 1) + + require.Equal(t, "video", assets[0].Title) + require.Equal(t, course.ID, assets[0].CourseID) + require.Equal(t, 1, int(assets[0].Prefix.Int16)) + require.Empty(t, assets[0].Chapter) + require.True(t, assets[0].Type.IsVideo()) + require.Equal(t, "e56ca866bff1691433766c60304a96583c1a410e53b33ef7d89cb29eac2a97ab", assets[0].Hash) + require.Len(t, assets[0].Attachments, 2) + + err = scanner.dao.List(ctx, &attachments, attachmentOptions) + require.NoError(t, err) + require.Len(t, attachments, 2) + require.Equal(t, filepath.Join(course.Path, "01 doc 1.pdf"), attachments[0].Path) + require.Equal(t, filepath.Join(course.Path, "01 index.html"), attachments[1].Path) + + // Add another PDF asset + afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 doc 2.pdf", course.Path), []byte("doc 2"), os.ModePerm) + + err = Processor(ctx, scanner, scan) + require.NoError(t, err) + + err = scanner.dao.List(ctx, &assets, assetOptions) + require.NoError(t, err) + require.Len(t, assets, 1) + + require.Equal(t, "video", assets[0].Title) + require.Equal(t, course.ID, assets[0].CourseID) + require.Equal(t, 1, int(assets[0].Prefix.Int16)) + require.Empty(t, assets[0].Chapter) + require.True(t, assets[0].Type.IsVideo()) + require.Equal(t, "e56ca866bff1691433766c60304a96583c1a410e53b33ef7d89cb29eac2a97ab", assets[0].Hash) + require.Len(t, assets[0].Attachments, 3) + + err = scanner.dao.List(ctx, &attachments, attachmentOptions) + require.NoError(t, err) + require.Len(t, attachments, 3) + require.Equal(t, filepath.Join(course.Path, "01 doc 1.pdf"), attachments[0].Path) + require.Equal(t, filepath.Join(course.Path, "01 index.html"), attachments[1].Path) + require.Equal(t, filepath.Join(course.Path, "01 doc 2.pdf"), attachments[2].Path) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestScanner_parseFilename(t *testing.T) { + t.Run("invalid", func(t *testing.T) { + var tests = []string{ + // No prefix + "file", + "file.file", + "file.avi", + " - file.avi", + "- - file.avi", + ".avi", + // Invalid prefix + "-1 - file.avi", + "a - file.avi", + "1.1 - file.avi", + "2.3-file.avi", + "1file.avi", + } + + for _, tt := range tests { + fb := parseFilename(tt) + require.Nil(t, fb) + } + }) + + t.Run("assets", func(t *testing.T) { + var tests = []struct { + in string + expected *parsedFilename + }{ + // Video (with varied filenames) + {"0 file 0.avi", &parsedFilename{prefix: 0, title: "file 0", asset: types.NewAsset("avi")}}, + {"001 file 1.mp4", &parsedFilename{prefix: 1, title: "file 1", asset: types.NewAsset("mp4")}}, + {"1-file.ogg", &parsedFilename{prefix: 1, title: "file", asset: types.NewAsset("ogg")}}, + {"2 - file.webm", &parsedFilename{prefix: 2, title: "file", asset: types.NewAsset("webm")}}, + {"3 -file.m4a", &parsedFilename{prefix: 3, title: "file", asset: types.NewAsset("m4a")}}, + {"4- file.opus", &parsedFilename{prefix: 4, title: "file", asset: types.NewAsset("opus")}}, + {"5000 --- file.wav", &parsedFilename{prefix: 5000, title: "file", asset: types.NewAsset("wav")}}, + {"0100 file.mp3", &parsedFilename{prefix: 100, title: "file", asset: types.NewAsset("mp3")}}, + // PDF + {"1 - doc.pdf", &parsedFilename{prefix: 1, title: "doc", asset: types.NewAsset("pdf")}}, + // HTML + {"1 index.html", &parsedFilename{prefix: 1, title: "index", asset: types.NewAsset("html")}}, + } + + for _, tt := range tests { + fb := parseFilename(tt.in) + require.Equal(t, tt.expected, fb, fmt.Sprintf("error for [%s]", tt.in)) + } + }) + + t.Run("attachments", func(t *testing.T) { + var tests = []struct { + in string + expected *parsedFilename + }{ + // No title + {"01", &parsedFilename{prefix: 1, title: "01"}}, + {"200.pdf", &parsedFilename{prefix: 200, title: "200.pdf"}}, + {"1 -.txt", &parsedFilename{prefix: 1, title: "1 -.txt"}}, + {"1 .txt", &parsedFilename{prefix: 1, title: "1 .txt"}}, + {"1 .pdf", &parsedFilename{prefix: 1, title: "1 .pdf"}}, + // No extension (fileName should have no prefix) + {"0 file 0", &parsedFilename{prefix: 0, title: "file 0"}}, + {"001 file 1", &parsedFilename{prefix: 1, title: "file 1"}}, + {"1001 - file", &parsedFilename{prefix: 1001, title: "file"}}, + {"0123-file", &parsedFilename{prefix: 123, title: "file"}}, + {"1 --- file", &parsedFilename{prefix: 1, title: "file"}}, + // Non-asset extension (fileName should have no prefix) + {"1 file.txt", &parsedFilename{prefix: 1, title: "file.txt"}}, + } + + for _, tt := range tests { + fb := parseFilename(tt.in) + require.Equal(t, tt.expected, fb, fmt.Sprintf("error for [%s]", tt.in)) + } + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestScanner_IsCard(t *testing.T) { + t.Run("invalid", func(t *testing.T) { + var tests = []string{ + "card", + "1234", + "1234.jpg", + "jpg", + "card.test.jpg", + "card.txt", + } + + for _, tt := range tests { + require.False(t, isCard(tt)) + } + }) + + t.Run("valid", func(t *testing.T) { + var tests = []string{ + "card.jpg", + "card.jpeg", + "card.png", + "card.webp", + "card.tiff", + } + + for _, tt := range tests { + require.True(t, isCard(tt)) + } + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestScanner_UpdateAssets(t *testing.T) { + t.Run("nothing added or deleted", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + numAssets := 10 + batchSize := 3 + + assets := []*models.Asset{} + for i := range numAssets { + asset := &models.Asset{ + CourseID: course.ID, + Title: fmt.Sprintf("Asset %d", i+1), + Prefix: sql.NullInt16{Int16: int16(i + 1), Valid: true}, + Chapter: fmt.Sprintf("Chapter %d", i/batchSize+1), + Type: *types.NewAsset("mp4"), + Path: fmt.Sprintf("/course-1/Chapter %d/%d asset.mp4", i/batchSize+1, (i%batchSize)+1), + Hash: security.RandomString(64), + } + require.NoError(t, scanner.dao.CreateAsset(ctx, asset)) + assets = append(assets, asset) + } + + err := updateAssets(ctx, scanner.dao, course.ID, assets) + require.NoError(t, err) + + count, err := scanner.dao.Count(ctx, &models.Asset{}, nil) + require.NoError(t, err) + require.Equal(t, numAssets, count) + }) + + t.Run("add", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + numAssets := 10 + batchSize := 3 + + assets := []*models.Asset{} + for i := range numAssets { + asset := &models.Asset{ + CourseID: course.ID, + Title: fmt.Sprintf("Asset %d", i+1), + Prefix: sql.NullInt16{Int16: int16(i + 1), Valid: true}, + Chapter: fmt.Sprintf("Chapter %d", i/batchSize+1), + Type: *types.NewAsset("mp4"), + Path: fmt.Sprintf("/course-1/Chapter %d/%d asset.mp4", i/batchSize+1, (i%batchSize)+1), + Hash: security.RandomString(64), + } + assets = append(assets, asset) + } + + // Add first 7 assets + err := updateAssets(ctx, scanner.dao, course.ID, assets[:7]) + require.NoError(t, err) + + count, err := scanner.dao.Count(ctx, &models.Asset{}, nil) + require.NoError(t, err) + require.Equal(t, 7, count) + + // Add remaining assets + err = updateAssets(ctx, scanner.dao, course.ID, assets) + require.NoError(t, err) + + count, err = scanner.dao.Count(ctx, &models.Asset{}, nil) + require.NoError(t, err) + require.Equal(t, numAssets, count) + + // Ensure all assets have an ID + assetsResult := []*models.Asset{} + err = scanner.dao.List(ctx, &assetsResult, nil) + require.NoError(t, err) + }) + + t.Run("delete", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + numAssets := 10 + batchSize := 3 + + assets := []*models.Asset{} + for i := range numAssets { + asset := &models.Asset{ + CourseID: course.ID, + Title: fmt.Sprintf("Asset %d", i+1), + Prefix: sql.NullInt16{Int16: int16(i + 1), Valid: true}, + Chapter: fmt.Sprintf("Chapter %d", i/batchSize+1), + Type: *types.NewAsset("mp4"), + Path: fmt.Sprintf("/course-1/Chapter %d/%d asset.mp4", i/batchSize+1, (i%batchSize)+1), + Hash: security.RandomString(64), + } + require.NoError(t, scanner.dao.CreateAsset(ctx, asset)) + assets = append(assets, asset) + } + + // Delete 2 assets + err := updateAssets(ctx, scanner.dao, course.ID, assets[2:]) + require.NoError(t, err) + + count, err := scanner.dao.Count(ctx, &models.Asset{}, nil) + require.NoError(t, err) + require.Equal(t, 8, count) + + // Delete another 2 assets + err = updateAssets(ctx, scanner.dao, course.ID, assets[4:]) + require.NoError(t, err) + + count, err = scanner.dao.Count(ctx, &models.Asset{}, nil) + require.NoError(t, err) + require.Equal(t, 6, count) + }) + + t.Run("rename", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + numAssets := 10 + batchSize := 3 + + assets := []*models.Asset{} + for i := range numAssets { + asset := &models.Asset{ + CourseID: course.ID, + Title: fmt.Sprintf("Asset %d", i+1), + Prefix: sql.NullInt16{Int16: int16(i + 1), Valid: true}, + Chapter: fmt.Sprintf("Chapter %d", i/batchSize+1), + Type: *types.NewAsset("mp4"), + Path: fmt.Sprintf("/course-1/Chapter %d/%d asset.mp4", i/batchSize+1, (i%batchSize)+1), + Hash: security.RandomString(64), + } + assets = append(assets, asset) + } + + // Add assets + err := updateAssets(ctx, scanner.dao, course.ID, assets) + require.NoError(t, err) + + count, err := scanner.dao.Count(ctx, &models.Asset{}, nil) + require.NoError(t, err) + require.Equal(t, numAssets, count) + + // Rename 2 assets + assets[2].Title = "Asset 100" + assets[2].Prefix = sql.NullInt16{Int16: 100, Valid: true} + assets[2].Chapter = "Chapter 100" + assets[2].Path = "/course-1/Chapter 100/100 asset.mp4" + + assets[4].Title = "Asset 200" + assets[4].Prefix = sql.NullInt16{Int16: 200, Valid: true} + assets[4].Chapter = "Chapter 200" + assets[4].Path = "/course-1/Chapter 200/200 asset.mp4" + + err = updateAssets(ctx, scanner.dao, course.ID, assets) + require.NoError(t, err) + + count, err = scanner.dao.Count(ctx, &models.Asset{}, nil) + require.NoError(t, err) + require.Equal(t, numAssets, count) + + // Ensure the assets were updated + asset2 := &models.Asset{Base: models.Base{ID: assets[2].ID}} + err = scanner.dao.GetById(ctx, asset2) + require.NoError(t, err) + require.Equal(t, "Asset 100", asset2.Title) + require.Equal(t, int16(100), asset2.Prefix.Int16) + require.Equal(t, "Chapter 100", asset2.Chapter) + require.Equal(t, "/course-1/Chapter 100/100 asset.mp4", asset2.Path) + + asset4 := &models.Asset{Base: models.Base{ID: assets[4].ID}} + err = scanner.dao.GetById(ctx, asset4) + require.NoError(t, err) + require.Equal(t, "Asset 200", asset4.Title) + require.Equal(t, int16(200), asset4.Prefix.Int16) + require.Equal(t, "Chapter 200", asset4.Chapter) + require.Equal(t, "/course-1/Chapter 200/200 asset.mp4", asset4.Path) + }) + + t.Run("swap", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + numAssets := 2 + + assets := []*models.Asset{} + for i := range numAssets { + asset := &models.Asset{ + CourseID: course.ID, + Title: fmt.Sprintf("Asset %d", i+1), + Prefix: sql.NullInt16{Int16: int16(i + 1), Valid: true}, + Chapter: fmt.Sprintf("Chapter %d", i+1), + Type: *types.NewAsset("mp4"), + Path: fmt.Sprintf("/course-1/Chapter %d/%d asset.mp4", i+1, i+1), + Hash: security.RandomString(64), + } + + require.NoError(t, scanner.dao.CreateAsset(ctx, asset)) + assets = append(assets, asset) + } + + // Swap assets title and path + assets[0].Title, assets[1].Title = assets[1].Title, assets[0].Title + assets[0].Path, assets[1].Path = assets[1].Path, assets[0].Path + + err := updateAssets(ctx, scanner.dao, course.ID, assets) + require.NoError(t, err) + + asset1 := &models.Asset{Base: models.Base{ID: assets[0].ID}} + err = scanner.dao.GetById(ctx, asset1) + require.NoError(t, err) + require.Equal(t, "Asset 2", asset1.Title) + require.Equal(t, "/course-1/Chapter 2/2 asset.mp4", asset1.Path) + + asset2 := &models.Asset{Base: models.Base{ID: assets[1].ID}} + err = scanner.dao.GetById(ctx, asset2) + require.NoError(t, err) + require.Equal(t, "Asset 1", asset2.Title) + require.Equal(t, "/course-1/Chapter 1/1 asset.mp4", asset2.Path) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestScanner_UpdateAttachments(t *testing.T) { + t.Run("nothing added or deleted", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/Chapter 1/1 Asset 1.mp4", + Hash: security.RandomString(64), + } + require.NoError(t, scanner.dao.CreateAsset(ctx, asset)) + + numAttachments := 10 + + attachments := []*models.Attachment{} + for i := range numAttachments { + attachment := &models.Attachment{ + CourseID: course.ID, + AssetID: asset.ID, + Title: fmt.Sprintf("Attachment %d", i+1), + Path: fmt.Sprintf("/course-1/Chapter 1/1 Attachment %d.pdf", i+1), + } + require.NoError(t, scanner.dao.CreateAttachment(ctx, attachment)) + attachments = append(attachments, attachment) + } + + err := updateAttachments(ctx, scanner.dao, course.ID, attachments) + require.NoError(t, err) + + count, err := scanner.dao.Count(ctx, &models.Attachment{}, nil) + require.NoError(t, err) + require.Equal(t, numAttachments, count) + }) + + t.Run("add", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/Chapter 1/1 Asset 1.mp4", + Hash: security.RandomString(64), + } + require.NoError(t, scanner.dao.CreateAsset(ctx, asset)) + + numAttachments := 10 + + attachments := []*models.Attachment{} + for i := range numAttachments { + attachment := &models.Attachment{ + CourseID: course.ID, + AssetID: asset.ID, + Title: fmt.Sprintf("Attachment %d", i+1), + Path: fmt.Sprintf("/course-1/Chapter 1/1 Attachment %d.pdf", i+1), + } + attachments = append(attachments, attachment) + } + + // Add first 7 attachments + err := updateAttachments(ctx, scanner.dao, course.ID, attachments[:7]) + require.NoError(t, err) + + count, err := scanner.dao.Count(ctx, &models.Attachment{}, nil) + require.NoError(t, err) + require.Equal(t, 7, count) + + // Add remaining attachments + err = updateAttachments(ctx, scanner.dao, course.ID, attachments) + require.NoError(t, err) + + count, err = scanner.dao.Count(ctx, &models.Attachment{}, nil) + require.NoError(t, err) + require.Equal(t, numAttachments, count) + }) + + t.Run("delete", func(t *testing.T) { + scanner, ctx, _ := setup(t) + + course := &models.Course{Title: "Course 1", Path: "/course-1"} + require.NoError(t, scanner.dao.CreateCourse(ctx, course)) + + asset := &models.Asset{ + CourseID: course.ID, + Title: "Asset 1", + Prefix: sql.NullInt16{Int16: 1, Valid: true}, + Chapter: "Chapter 1", + Type: *types.NewAsset("mp4"), + Path: "/course-1/Chapter 1/1 Asset 1.mp4", + Hash: security.RandomString(64), + } + require.NoError(t, scanner.dao.CreateAsset(ctx, asset)) + + numAttachments := 10 + + attachments := []*models.Attachment{} + for i := range numAttachments { + attachment := &models.Attachment{ + CourseID: course.ID, + AssetID: asset.ID, + Title: fmt.Sprintf("Attachment %d", i+1), + Path: fmt.Sprintf("/course-1/Chapter 1/1 Attachment %d.pdf", i+1), + } + require.NoError(t, scanner.dao.CreateAttachment(ctx, attachment)) + attachments = append(attachments, attachment) + } + + // Delete 2 attachments + err := updateAttachments(ctx, scanner.dao, course.ID, attachments[2:]) + require.NoError(t, err) + + count, err := scanner.dao.Count(ctx, &models.Attachment{}, nil) + require.NoError(t, err) + require.Equal(t, 8, count) + + // Delete another 2 attachments + err = updateAttachments(ctx, scanner.dao, course.ID, attachments[4:]) + require.NoError(t, err) + + count, err = scanner.dao.Count(ctx, &models.Attachment{}, nil) + require.NoError(t, err) + require.Equal(t, 6, count) + }) +} diff --git a/utils/coursescan/errors.go b/utils/coursescan/errors.go new file mode 100644 index 0000000..74c6ad6 --- /dev/null +++ b/utils/coursescan/errors.go @@ -0,0 +1,7 @@ +package coursescan + +import "errors" + +var ( + ErrNilScan = errors.New("scan cannot be empty") +) diff --git a/utils/errors.go b/utils/errors.go new file mode 100644 index 0000000..a6ea811 --- /dev/null +++ b/utils/errors.go @@ -0,0 +1,24 @@ +package utils + +import "errors" + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +var ( + // Generic + ErrNilPtr = errors.New("nil pointer") + ErrNotPtr = errors.New("requires a pointer") + ErrNotModeler = errors.New("does not implement Modeler interface") + ErrEmbedded = errors.New("embedded struct does not implement Definer interface") + ErrInvalidValue = errors.New("invalid value") + ErrNotStruct = errors.New("not a struct") + ErrNotSlice = errors.New("not a slice") + ErrNoTable = errors.New("table name cannot be empty") + + // DB + ErrInvalidWhere = errors.New("where clause cannot be empty") + + // Model + ErrInvalidId = errors.New("id cannot be empty") + ErrInvalidKey = errors.New("key cannot be empty") +) diff --git a/utils/jobs/course_scanner.go b/utils/jobs/course_scanner.go deleted file mode 100644 index 45a86a9..0000000 --- a/utils/jobs/course_scanner.go +++ /dev/null @@ -1,704 +0,0 @@ -package jobs - -import ( - "database/sql" - "errors" - "log/slog" - "os" - "path/filepath" - "regexp" - "strconv" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/daos" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils" - "github.com/geerew/off-course/utils/appFs" - "github.com/geerew/off-course/utils/security" - "github.com/geerew/off-course/utils/types" -) - -var ( - loggerType = slog.Any("type", types.LogTypeCourseScanner) -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// CourseScannerProcessorFn is a function that processes a course scan job -type CourseScannerProcessorFn func(*CourseScanner, *models.Scan) error - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// CourseScanner scans a course and finds assets and attachments -type CourseScanner struct { - appFs *appFs.AppFs - db database.Database - logger *slog.Logger - - jobSignal chan bool - - // Required DAOs - courseDao *daos.CourseDao - scanDao *daos.ScanDao - assetDao *daos.AssetDao - attachmentDao *daos.AttachmentDao -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// CourseScannerConfig is the config for a CourseScanner -type CourseScannerConfig struct { - Db database.Database - AppFs *appFs.AppFs - Logger *slog.Logger -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// NewCourseScanner creates a new CourseScanner -func NewCourseScanner(config *CourseScannerConfig) *CourseScanner { - return &CourseScanner{ - appFs: config.AppFs, - db: config.Db, - logger: config.Logger, - jobSignal: make(chan bool, 1), - courseDao: daos.NewCourseDao(config.Db), - scanDao: daos.NewScanDao(config.Db), - assetDao: daos.NewAssetDao(config.Db), - attachmentDao: daos.NewAttachmentDao(config.Db), - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Add inserts a course scan job into the db -func (cs *CourseScanner) Add(courseId string) (*models.Scan, error) { - // Check if the course exists - course, err := cs.courseDao.Get(courseId, nil, nil) - if err != nil { - return nil, err - } - - // Do nothing when a scan job is already in progress - if course.ScanStatus != "" { - cs.logger.Debug( - "Scan already in progress", - loggerType, - slog.String("path", course.Path), - ) - - // Get the scan from the db and return that - scan, err := cs.scanDao.Get(courseId, nil) - if err != nil { - return nil, err - } - - return scan, nil - } - - // Add the job - scan := &models.Scan{CourseID: courseId, Status: types.NewScanStatus(types.ScanStatusWaiting)} - if err := cs.scanDao.Create(scan, nil); err != nil { - return nil, err - } - - // Signal the worker to process the job - select { - case cs.jobSignal <- true: - default: - } - - cs.logger.Info( - "Added scan job", - loggerType, - slog.String("path", course.Path), - ) - - return scan, nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// Worker processes jobs out of the DB sequentially -func (cs *CourseScanner) Worker(processor CourseScannerProcessorFn, processingDone chan bool) { - cs.logger.Debug("Started course scanner worker", loggerType) - - for { - <-cs.jobSignal - for { - // Get the next scan - job, err := cs.scanDao.Next(nil) - if err != nil { - cs.logger.Error( - "Failed to look up next scan job", - loggerType, - slog.String("error", err.Error()), - ) - - if processingDone != nil { - processingDone <- true - } - - break - } - - if job == nil { - if processingDone != nil { - processingDone <- true - } - cs.logger.Debug("Finished processing all jobs", loggerType) - break - } - - cs.logger.Info( - "Processing scan job", - loggerType, - slog.String("job", job.ID), - slog.String("path", job.CoursePath), - ) - - // Process the job - err = processor(cs, job) - if err != nil { - cs.logger.Error( - "Failed to process scan job", - loggerType, - slog.String("error", err.Error()), - slog.String("path", job.CoursePath), - ) - } - - // Cleanup - if err := cs.scanDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": job.ID}}, nil); err != nil { - cs.logger.Error( - "Failed to delete scan job", - loggerType, - slog.String("error", err.Error()), - slog.String("job", job.ID), - ) - - if processingDone != nil { - processingDone <- true - } - - break - } - - } - } -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// assetMap and attachmentMap are maps to hold encountered assets and attachments -type assetMap map[string]map[int]*models.Asset -type attachmentMap map[string]map[int][]*models.Attachment - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// CourseProcessor scans a course and finds assets and attachments -func CourseProcessor(cs *CourseScanner, scan *models.Scan) error { - if scan == nil { - return errors.New("scan cannot be empty") - } - - // Get the course for this scan - course, err := cs.courseDao.Get(scan.CourseID, nil, nil) - if err != nil { - // If the course does not exist, we can ignore this job - if err == sql.ErrNoRows { - cs.logger.Debug( - "Ignoring scan job as the course no longer exists", - loggerType, - slog.String("path", scan.CoursePath), - ) - - return nil - } - - return err - } - - // Check the availability of the course. When a course is unavailable, we do not want to scan - // it. This prevents assets and attachments from being deleted unintentionally - _, err = cs.appFs.Fs.Stat(course.Path) - if err != nil { - if os.IsNotExist(err) { - cs.logger.Debug( - "Ignoring scan job as the course path does not exist", - loggerType, - slog.String("path", scan.CoursePath), - ) - - if course.Available { - course.Available = false - err = cs.courseDao.Update(course, nil) - if err != nil { - return err - } - } - - return nil - } - - return err - } - - // If the course is currently marked as unavailable, set it as available - if !course.Available { - course.Available = true - err := cs.courseDao.Update(course, nil) - if err != nil { - return err - } - } - - // Set the scan status to processing - scan.Status = types.NewScanStatus(types.ScanStatusProcessing) - err = cs.scanDao.Update(scan, nil) - if err != nil { - return err - } - - cardPath := "" - - // Get all files down to a depth of 2 (immediate files and and files within 'chapters') - files, err := cs.appFs.ReadDirFlat(course.Path, 2) - if err != nil { - return err - } - - // Maps to hold encountered assets and attachments by [chapter][prefix] - assetsMap := assetMap{} - attachmentsMap := attachmentMap{} - - for _, fp := range files { - normalizedFilePath := utils.NormalizeWindowsDrive(fp) - // Get the fileName from the path (ex /path/to/file.txt -> file.txt) - fileName := filepath.Base(normalizedFilePath) - - // Get the fileDir from the path (ex /path/to/file.txt -> /path/to) - fileDir := filepath.Dir(normalizedFilePath) - isRootDir := fileDir == utils.NormalizeWindowsDrive(course.Path) - - // Check if this file is a card. Only check when not yet set and the file exists at the - // course root - if cardPath == "" && isRootDir { - if isCard(fileName) { - cardPath = normalizedFilePath - continue - } - } - - // Get the chapter for this file. This will be empty when the file exists at the course - // root - chapter := "" - if !isRootDir { - chapter = filepath.Base(fileDir) - } - - // Create a new map entry for this chapter - if _, exists := assetsMap[chapter]; !exists { - assetsMap[chapter] = make(map[int]*models.Asset) - } - - // Add chapter to attachments map - if _, exists := attachmentsMap[chapter]; !exists { - attachmentsMap[chapter] = make(map[int][]*models.Attachment) - } - - // Parse the file name to see if it is an asset, attachment or neither - pfn := parseFileName(fileName) - - if pfn == nil { - cs.logger.Debug( - "Ignoring file during scan job", - loggerType, - slog.String("path", scan.CoursePath), - slog.String("file", normalizedFilePath), - ) - - continue - } - - if pfn.asset != nil { - // Check if we have an existing asset for this [chapter][prefix] - existing, exists := assetsMap[chapter][pfn.prefix] - - // Get a (partial) hash of the asset - hash, err := cs.appFs.PartialHash(normalizedFilePath, 1024*1024) - if err != nil { - return err - } - - newAsset := &models.Asset{ - Title: pfn.title, - Prefix: sql.NullInt16{Int16: int16(pfn.prefix), Valid: true}, - CourseID: course.ID, - Chapter: chapter, - Path: normalizedFilePath, - Type: *pfn.asset, - Hash: hash, - } - - if !exists { - // New asset - assetsMap[chapter][pfn.prefix] = newAsset - } else { - // Found an existing asset. Check if this new asset has a higher priority than the - // existing asset. The priority is video > html > pdf - if newAsset.Type.IsVideo() && !existing.Type.IsVideo() || - newAsset.Type.IsHTML() && existing.Type.IsPDF() { - // Asset -> Replace the existing asset with the new asset and set the existing - // asset as an attachment - cs.logger.Debug( - "Replacing existing asset with new asset", - loggerType, - slog.String("path", scan.CoursePath), - slog.String("file", normalizedFilePath), - ) - - assetsMap[chapter][pfn.prefix] = newAsset - - attachmentsMap[chapter][pfn.prefix] = append(attachmentsMap[chapter][pfn.prefix], &models.Attachment{ - Title: existing.Title + filepath.Ext(existing.Path), - Path: existing.Path, - CourseID: course.ID, - }) - } else { - // Attachment -> This new asset has a lower priority than the existing asset - attachmentsMap[chapter][pfn.prefix] = append(attachmentsMap[chapter][pfn.prefix], &models.Attachment{ - Title: pfn.attachmentTitle, - Path: normalizedFilePath, - CourseID: course.ID, - }) - } - } - } else { - // Attachment - attachmentsMap[chapter][pfn.prefix] = append(attachmentsMap[chapter][pfn.prefix], &models.Attachment{ - Title: pfn.attachmentTitle, - Path: normalizedFilePath, - CourseID: course.ID, - }) - } - } - - course.CardPath = cardPath - - // Run in a transaction so it all commits, or it rolls back - err = cs.db.RunInTransaction(func(tx *database.Tx) error { - // Convert the assets map to a slice - assets := make([]*models.Asset, 0, len(files)) - for _, chapterMap := range assetsMap { - for _, asset := range chapterMap { - assets = append(assets, asset) - } - } - - // Update the assets in DB - if len(assets) > 0 { - err = updateAssets(cs.assetDao, tx, course.ID, assets) - if err != nil { - return err - } - } - - // Convert the attachments map to a slice - attachments := []*models.Attachment{} - for chapter, attachmentMap := range attachmentsMap { - for prefix, potentialAttachments := range attachmentMap { - // Only add attachments when there is an assert - if asset, exists := assetsMap[chapter][prefix]; exists { - for _, attachment := range potentialAttachments { - attachment.AssetID = asset.ID - attachments = append(attachments, attachment) - } - } - } - } - - // Update the attachments in DB - if len(attachments) > 0 { - err = updateAttachments(cs.attachmentDao, tx, course.ID, attachments) - if err != nil { - return err - } - } - - // Update the course (card_path, updated_at) - if err = cs.courseDao.Update(course, tx); err != nil { - return err - } - - return nil - }) - - return err -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// PRIVATE -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// parsedFileName that holds information following a filename being parsed -type parsedFileName struct { - prefix int - title string - ext string - attachmentTitle string - asset *types.Asset -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// A regex for parsing a file name into a prefix, title, and extension -// -// Valid patterns: -// -// `<prefix>` -// `<prefix>.<ext>` -// `<prefix> <title>` -// `<prefix>-<title>` -// `<prefix> - <title>` -// `<prefix> <title>.<ext>` -// `<prefix>-<title>.<ext>` -// `<prefix> - <title>.<ext>` -// -// - <prefix> is required and must be a number -// - A dash (-) is optional -// - <title> is optional and can be any non-empty string -// - <ext> is optional -var fileNameRegex = regexp.MustCompile(`^\s*(?P<Prefix>[0-9]+)((?:\s+-+\s+|\s+-+|\s+|-+\s*)(?P<Title>[^.][^.]*)?)?(?:\.(?P<Ext>\w+))?$`) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// parseFileName parses a file name and determines if it represents an asset, attachment, or neither -// -// A file is an asset when it matches `<prefix> <title>.<ext>` and <ext> is a valid `types.AssetType` -// -// A file is an attachment when it has a <prefix>, and optionally a <title> and/or <ext>, whereby <ext> -// is not a valid `types.AssetType` -// -// Parameters: -// - fileName: The name of the file to parse -// -// Returns: -// - *parsedFileName: A struct containing parsed information if the file name matches the expected -// format, otherwise nil -func parseFileName(fileName string) *parsedFileName { - pfn := &parsedFileName{} - - // Match the file name against the regex and ignore if no match - matches := fileNameRegex.FindStringSubmatch(fileName) - if len(matches) == 0 { - return nil - } - - // Convert the prefix to an int and ignore if missing or not a number - prefix, err := strconv.Atoi(matches[fileNameRegex.SubexpIndex("Prefix")]) - if err != nil { - return nil - } - - pfn.prefix = prefix - pfn.title = matches[fileNameRegex.SubexpIndex("Title")] - - // Check the title. When empty, this is an attachment - if pfn.title == "" { - pfn.attachmentTitle = fileName - return pfn - } - - // Check the extension. When empty, this is an attachment - pfn.ext = matches[fileNameRegex.SubexpIndex("Ext")] - if pfn.ext == "" { - pfn.attachmentTitle = pfn.title - return pfn - } - - // Set the attachment title, in the event that this is an attachment - pfn.attachmentTitle = pfn.title + "." + pfn.ext - - // Check if this is a valid asset - pfn.asset = types.NewAsset(pfn.ext) - - return pfn -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// isCard determines if a given file name represents a card based on its name and extension -// -// Parameters: -// - fileName: The name of the file to check -// -// Returns: -// - bool: True if the file name represents a card, false otherwise -func isCard(fileName string) bool { - // Get the extension. If there is no extension, return false - ext := filepath.Ext(fileName) - if ext == "" { - return false - } - - fileWithoutExt := fileName[:len(fileName)-len(ext)] - if fileWithoutExt != "card" { - return false - } - - // Check if the extension is supported - switch ext[1:] { - case - "jpg", - "jpeg", - "png", - "webp", - "tiff": - return true - } - - return false -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// updateAssets updates the assets in the database based on the assets found on disk. It compares -// the existing assets in the database with the assets found on disk, and performs the necessary -// additions, deletions, and updates -// -// Parameters: -// - assetDao: The DAO used to interact with the assets table in the database -// - tx: The database transaction within which all operations should be performed -// - courseId: The ID of the course to which the assets belong -// - assets: A slice of Asset structs representing the assets found on disk -// -// Returns: -// - error: An error if any operation fails, otherwise nil -func updateAssets(assetDao *daos.AssetDao, tx *database.Tx, courseId string, assets []*models.Asset) error { - // Get existing assets - dbParams := &database.DatabaseParams{ - Where: squirrel.Eq{assetDao.Table() + ".course_id": courseId}, - } - - existingAssets, err := assetDao.List(dbParams, tx) - if err != nil { - return err - } - - // Compare the assets found on disk to assets found in DB and identify which assets to add and - // which assets to delete - toAdd, toDelete, err := utils.DiffSliceOfStructsByKey(assets, existingAssets, "Hash") - if err != nil { - return err - } - - // Add assets - // TODO: This could be optimized by using a bulk insert - for _, asset := range toAdd { - if err := assetDao.Create(asset, tx); err != nil { - return err - } - } - - // Bulk delete assets - whereClause := squirrel.Or{} - for _, deleteAsset := range toDelete { - whereClause = append(whereClause, squirrel.Eq{"id": deleteAsset.ID}) - } - - if err := assetDao.Delete(&database.DatabaseParams{Where: whereClause}, tx); err != nil { - return err - } - - // Identify the existing assets whose information has changed - existingAssetsMap := make(map[string]*models.Asset) - for _, existingAsset := range existingAssets { - existingAssetsMap[existingAsset.Hash] = existingAsset - } - - randomTempSuffix := security.RandomString(10) - updatedAssets := make([]*models.Asset, 0, len(assets)) - - // On the first pass we update the existing assets with details of the new asset. In addition, we - // set the path to be path+randomTempSuffix. This is to prevent a `unique path constraint` error if, - // for example, 2 files are have their titles swapped. - // - // On the second pass we update the existing assets and remove the randomTempSuffix from the path - for _, asset := range assets { - if existingAsset, exists := existingAssetsMap[asset.Hash]; exists { - asset.ID = existingAsset.ID - - if !utils.CompareStructs(asset, existingAsset, []string{"CreatedAt", "UpdatedAt"}) { - asset.Path = asset.Path + randomTempSuffix - updatedAssets = append(updatedAssets, asset) - - // The assets has been updated to have the existing assets ID, so this will update the - // existing asset with the details of the new asset - if err := assetDao.Update(asset, tx); err != nil { - return err - } - } - } - } - - for _, asset := range updatedAssets { - asset.Path = asset.Path[:len(asset.Path)-len(randomTempSuffix)] - - if err := assetDao.Update(asset, tx); err != nil { - return err - } - } - - return nil -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// updateAttachments updates the attachments in the database based on the attachments found on disk. -// It compares the existing attachments in the database with the attachments found on disk, and performs -// the necessary additions and deletions -// -// Parameters: -// - attachmentDao: The DAO used to interact with the attachments table in the database -// - tx: The database transaction within which all operations should be performed -// - courseId: The ID of the course to which the attachments belong -// - attachments: A slice of Attachment structs representing the attachments found on disk -// -// Returns: -// - error: An error if any operation fails, otherwise nil -func updateAttachments(attachmentDao *daos.AttachmentDao, tx *database.Tx, courseId string, attachments []*models.Attachment) error { - // Get existing attachments - dbParams := &database.DatabaseParams{ - Where: squirrel.Eq{attachmentDao.Table() + ".course_id": courseId}, - } - - existingAttachments, err := attachmentDao.List(dbParams, tx) - if err != nil { - return err - } - - // Compare the attachments found on disk to attachments found in DB - toAdd, toDelete, err := utils.DiffSliceOfStructsByKey(attachments, existingAttachments, "Path") - if err != nil { - return err - } - - // Add the missing attachments - for _, attachment := range toAdd { - if err := attachmentDao.Create(attachment, tx); err != nil { - return err - } - } - - // Delete the irrelevant attachments - for _, attachment := range toDelete { - err := attachmentDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": attachment.ID}}, tx) - if err != nil { - return err - } - } - - return nil -} diff --git a/utils/jobs/course_scanner_test.go b/utils/jobs/course_scanner_test.go deleted file mode 100644 index 6896b7a..0000000 --- a/utils/jobs/course_scanner_test.go +++ /dev/null @@ -1,1156 +0,0 @@ -package jobs - -import ( - "database/sql" - "errors" - "fmt" - "log/slog" - "math/rand" - "os" - "path/filepath" - "sync" - "testing" - - "github.com/Masterminds/squirrel" - "github.com/geerew/off-course/daos" - "github.com/geerew/off-course/database" - "github.com/geerew/off-course/models" - "github.com/geerew/off-course/utils/appFs" - "github.com/geerew/off-course/utils/logger" - "github.com/geerew/off-course/utils/security" - "github.com/geerew/off-course/utils/types" - "github.com/spf13/afero" - "github.com/stretchr/testify/require" -) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// setupCourseScanner initializes the CourseScanner and its dependencies -func setupCourseScanner(t *testing.T) (*CourseScanner, *database.DatabaseManager, *[]*logger.Log, *sync.Mutex) { - t.Helper() - - // Logger - var logs []*logger.Log - var logsMux sync.Mutex - logger, _, err := logger.InitLogger(&logger.BatchOptions{ - BatchSize: 1, - WriteFn: logger.TestWriteFn(&logs, &logsMux), - }) - require.NoError(t, err, "Failed to initialize logger") - - // Filesystem - appFs := appFs.NewAppFs(afero.NewMemMapFs(), logger) - - // Db - dbManager, err := database.NewSqliteDBManager(&database.DatabaseConfig{ - IsDebug: false, - DataDir: "./oc_data", - AppFs: appFs, - InMemory: true, - }) - - require.Nil(t, err) - require.NotNil(t, dbManager) - - // Course scanner - courseScanner := NewCourseScanner(&CourseScannerConfig{ - Db: dbManager.DataDb, - AppFs: appFs, - Logger: logger, - }) - - return courseScanner, dbManager, &logs, &logsMux -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseScanner_Add(t *testing.T) { - t.Run("success", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Build() - - scan, err := scanner.Add(testData[0].ID) - require.Nil(t, err) - require.Equal(t, testData[0].ID, scan.CourseID) - }) - - t.Run("duplicate", func(t *testing.T) { - scanner, dbManager, logs, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Build() - - firstScan, err := scanner.Add(testData[0].ID) - require.Nil(t, err) - require.Equal(t, testData[0].ID, firstScan.CourseID) - - // Add the same course again - secondScan, err := scanner.Add(testData[0].ID) - require.Nil(t, err) - require.Equal(t, secondScan.ID, firstScan.ID) - require.NotEmpty(t, *logs) - require.Equal(t, "Scan already in progress", (*logs)[len(*logs)-1].Message) - require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) - }) - - t.Run("invalid course", func(t *testing.T) { - scanner, _, _, _ := setupCourseScanner(t) - - scan, err := scanner.Add("test") - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, scan) - }) - - t.Run("not blocked", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(2).Build() - - scan1, err := scanner.Add(testData[0].ID) - require.Nil(t, err) - require.Equal(t, testData[0].ID, scan1.CourseID) - - scan2, err := scanner.Add(testData[1].ID) - require.Nil(t, err) - require.Equal(t, testData[1].ID, scan2.CourseID) - }) - - t.Run("db error", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Build() - scanDao := daos.NewScanDao(dbManager.DataDb) - - _, err := dbManager.DataDb.Exec("DROP TABLE IF EXISTS " + scanDao.Table()) - require.Nil(t, err) - - scan, err := scanner.Add(testData[0].ID) - require.ErrorContains(t, err, fmt.Sprintf("no such table: %s", scanDao.Table())) - require.Nil(t, scan) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseScanner_Worker(t *testing.T) { - t.Run("single job", func(t *testing.T) { - scanner, dbManager, logs, logsMux := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Build() - - // Start the worker - var processingDone = make(chan bool, 1) - go scanner.Worker(func(*CourseScanner, *models.Scan) error { - return nil - }, processingDone) - - // Add the job - scan, err := scanner.Add(testData[0].ID) - require.Nil(t, err) - require.Equal(t, scan.CourseID, testData[0].ID) - - // Wait for the worker to finish - <-processingDone - - // Assert the scan job was deleted from the DB - s, err := scanner.scanDao.Get(testData[0].ID, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, s) - - logsMux.Lock() - defer logsMux.Unlock() - - require.NotEmpty(t, *logs) - require.Equal(t, "Finished processing all jobs", (*logs)[len(*logs)-1].Message) - require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) - }) - - t.Run("several jobs", func(t *testing.T) { - scanner, dbManager, logs, logsMux := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(3).Build() - - for _, course := range testData { - _, err := scanner.Add(course.ID) - require.Nil(t, err) - } - - // Start the worker - var processingDone = make(chan bool, 1) - go scanner.Worker(func(*CourseScanner, *models.Scan) error { - return nil - }, processingDone) - - // Wait for the worker to finish - <-processingDone - - // Assert the scan job was deleted from the DB - for _, course := range testData { - s, err := scanner.scanDao.Get(course.ID, nil) - require.ErrorIs(t, err, sql.ErrNoRows) - require.Nil(t, s) - } - - logsMux.Lock() - defer logsMux.Unlock() - - require.NotEmpty(t, *logs) - require.Equal(t, "Finished processing all jobs", (*logs)[len(*logs)-1].Message) - require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) - }) - - t.Run("error processing", func(t *testing.T) { - scanner, dbManager, logs, logsMux := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Build() - - // Start the worker - var processingDone = make(chan bool, 1) - go scanner.Worker(func(*CourseScanner, *models.Scan) error { - return errors.New("processing error") - }, processingDone) - - scan, err := scanner.Add(testData[0].ID) - require.Nil(t, err) - require.Equal(t, scan.CourseID, testData[0].ID) - - // Wait for the worker to finish - <-processingDone - - logsMux.Lock() - defer logsMux.Unlock() - - require.NotEmpty(t, *logs) - require.Greater(t, len(*logs), 2) - require.Equal(t, "Failed to process scan job", (*logs)[len(*logs)-2].Message) - require.Equal(t, slog.LevelError, (*logs)[len(*logs)-2].Level) - require.Equal(t, "Finished processing all jobs", (*logs)[len(*logs)-1].Message) - require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) - }) - - t.Run("scan error", func(t *testing.T) { - scanner, dbManager, logs, logsMux := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Build() - scanDao := daos.NewScanDao(dbManager.DataDb) - - // Add the job - scan, err := scanner.Add(testData[0].ID) - require.Nil(t, err) - require.Equal(t, scan.CourseID, testData[0].ID) - - // Drop the DB - _, err = dbManager.DataDb.Exec("DROP TABLE IF EXISTS " + scanDao.Table()) - require.Nil(t, err) - - // Start the worker - var processingDone = make(chan bool, 1) - go scanner.Worker(func(*CourseScanner, *models.Scan) error { - return nil - }, processingDone) - - // Wait for the worker to finish - <-processingDone - - logsMux.Lock() - defer logsMux.Unlock() - - require.NotEmpty(t, *logs) - require.Equal(t, "Failed to look up next scan job", (*logs)[len(*logs)-1].Message) - require.Equal(t, slog.LevelError, (*logs)[len(*logs)-1].Level) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseScanner_CourseProcessor(t *testing.T) { - t.Run("scan nil", func(t *testing.T) { - scanner, _, _, _ := setupCourseScanner(t) - - err := CourseProcessor(scanner, nil) - require.EqualError(t, err, "scan cannot be empty") - }) - - t.Run("error getting course", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - courseDao := daos.NewCourseDao(dbManager.DataDb) - - // Drop the table - _, err := dbManager.DataDb.Exec("DROP TABLE IF EXISTS " + courseDao.Table()) - require.Nil(t, err) - - err = CourseProcessor(scanner, testData[0].Scan) - require.ErrorContains(t, err, fmt.Sprintf("no such table: %s", courseDao.Table())) - }) - - t.Run("course unavailable", func(t *testing.T) { - scanner, dbManager, logs, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - - // Mark the course as available - testData[0].Available = true - err := scanner.courseDao.Update(testData[0].Course, nil) - require.Nil(t, err) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - require.NotEmpty(t, *logs) - require.Equal(t, "Ignoring scan job as the course path does not exist", (*logs)[len(*logs)-1].Message) - require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) - }) - - t.Run("card", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - // ---------------------------- - // Found card - // ---------------------------- - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses([]string{"course 1"}).Scan().Build() - require.Empty(t, testData[0].CardPath) - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - scanner.appFs.Fs.Create(filepath.Join(testData[0].Path, "card.jpg")) - - err := CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - c, err := scanner.courseDao.Get(testData[0].ID, nil, nil) - require.Nil(t, err) - require.Equal(t, filepath.Join(testData[0].Path, "card.jpg"), c.CardPath) - - // ---------------------------- - // Ignore card in chapter - // ---------------------------- - testData = daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses([]string{"course "}).Scan().Build() - require.Empty(t, testData[0].CardPath) - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - scanner.appFs.Fs.Create(filepath.Join(testData[0].Path, "chapter 1", "card.jpg")) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - c, err = scanner.courseDao.Get(testData[0].ID, nil, nil) - require.Nil(t, err) - require.Empty(t, c.CardPath) - require.Empty(t, testData[0].CardPath) - - // ---------------------------- - // Multiple cards types - // ---------------------------- - testData = daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses([]string{"course 3"}).Scan().Build() - require.Empty(t, testData[0].CardPath) - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - scanner.appFs.Fs.Create(filepath.Join(testData[0].Path, "card.jpg")) - scanner.appFs.Fs.Create(filepath.Join(testData[0].Path, "card.png")) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - c, err = scanner.courseDao.Get(testData[0].ID, nil, nil) - require.Nil(t, err) - require.Equal(t, filepath.Join(testData[0].Path, "card.jpg"), c.CardPath) - }) - - t.Run("card error", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses([]string{"course 4"}).Scan().Build() - require.Empty(t, testData[0].CardPath) - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/card.jpg", testData[0].Path)) - - // Rename the card_path column - courseDao := daos.NewCourseDao(dbManager.DataDb) - _, err := dbManager.DataDb.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN card_path TO ignore_card_path", courseDao.Table())) - require.Nil(t, err) - - err = CourseProcessor(scanner, testData[0].Scan) - require.ErrorContains(t, err, "no such column: card_path") - }) - - t.Run("ignore files", func(t *testing.T) { - scanner, dbManager, logs, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/file 1", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/file.file", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/file.avi", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/ - file.avi", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/- - file.avi", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/.avi", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/-1 - file.avi", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/a - file.avi", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/1.1 - file.avi", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/2.3-file.avi", testData[0].Path)) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/1file.avi", testData[0].Path)) - - err := CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assetDao := daos.NewAssetDao(dbManager.DataDb) - - assets, err := scanner.assetDao.List(&database.DatabaseParams{Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}}, nil) - require.Nil(t, err) - require.Zero(t, len(assets)) - require.NotEmpty(t, *logs) - require.Equal(t, "Ignoring file during scan job", (*logs)[len(*logs)-1].Message) - require.Equal(t, slog.LevelDebug, (*logs)[len(*logs)-1].Level) - }) - - t.Run("assets", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - assetDao := daos.NewAssetDao(dbManager.DataDb) - - dbParams := &database.DatabaseParams{ - OrderBy: []string{assetDao.Table() + ".chapter asc", assetDao.Table() + ".prefix asc"}, - Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}, - } - - // ---------------------------- - // Add 2 assets - // ---------------------------- - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - - afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 file 1.mkv", testData[0].Path), []byte("file 1"), os.ModePerm) - afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/02 file 2.html", testData[0].Path), []byte("file 2"), os.ModePerm) - afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/should ignore", testData[0].Path), []byte("ignore"), os.ModePerm) - - err := CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assets, err := scanner.assetDao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 2) - - require.Equal(t, "file 1", assets[0].Title) - require.Equal(t, testData[0].ID, assets[0].CourseID) - require.Equal(t, 1, int(assets[0].Prefix.Int16)) - require.Empty(t, assets[0].Chapter) - require.True(t, assets[0].Type.IsVideo()) - require.Equal(t, "ca934260de4b6eb696e4e9912447bc7f2bd7b614da6879b7addef8e03dca71d1", assets[0].Hash) - - require.Equal(t, "file 2", assets[1].Title) - require.Equal(t, testData[0].ID, assets[1].CourseID) - require.Equal(t, 2, int(assets[1].Prefix.Int16)) - require.Empty(t, assets[1].Chapter) - require.True(t, assets[1].Type.IsHTML()) - require.Equal(t, "21b5bfe70ae6b203182d12bdde12f6f086000e37c894187a47b664ea7ec2331a", assets[1].Hash) - - // ---------------------------- - // Delete asset - // ---------------------------- - scanner.appFs.Fs.Remove(fmt.Sprintf("%s/01 file 1.mkv", testData[0].Path)) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assets, err = scanner.assetDao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 1) - - require.Equal(t, "file 2", assets[0].Title) - require.Equal(t, testData[0].ID, assets[0].CourseID) - require.Equal(t, 2, int(assets[0].Prefix.Int16)) - require.Empty(t, assets[0].Chapter) - require.True(t, assets[0].Type.IsHTML()) - require.Equal(t, "21b5bfe70ae6b203182d12bdde12f6f086000e37c894187a47b664ea7ec2331a", assets[0].Hash) - - // ---------------------------- - // Add chapter asset - // ---------------------------- - afero.WriteFile(scanner.appFs.Fs, fmt.Sprintf("%s/01 Chapter 1/01 file 3.pdf", testData[0].Path), []byte("file 3"), os.ModePerm) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assets, err = scanner.assetDao.List(dbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 2) - - require.Equal(t, "file 2", assets[0].Title) - require.Equal(t, testData[0].ID, assets[0].CourseID) - require.Equal(t, 2, int(assets[0].Prefix.Int16)) - require.Empty(t, assets[0].Chapter) - require.True(t, assets[0].Type.IsHTML()) - require.Equal(t, "21b5bfe70ae6b203182d12bdde12f6f086000e37c894187a47b664ea7ec2331a", assets[0].Hash) - - require.Equal(t, "file 3", assets[1].Title) - require.Equal(t, testData[0].ID, assets[1].CourseID) - require.Equal(t, 1, int(assets[1].Prefix.Int16)) - require.Equal(t, "01 Chapter 1", assets[1].Chapter) - require.True(t, assets[1].Type.IsPDF()) - require.Equal(t, "333940e348f410361b399939d5e120c72896843ad2bea2e5a961cba6818a9ad9", assets[1].Hash) - }) - - t.Run("assets error", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - assetDao := daos.NewAssetDao(dbManager.DataDb) - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - scanner.appFs.Fs.Create(fmt.Sprintf("%s/01 video.mkv", testData[0].Path)) - - _, err := dbManager.DataDb.Exec("DROP TABLE IF EXISTS " + assetDao.Table()) - require.Nil(t, err) - - err = CourseProcessor(scanner, testData[0].Scan) - require.ErrorContains(t, err, "no such table: "+assetDao.Table()) - }) - - t.Run("attachments", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - assetDao := daos.NewAssetDao(dbManager.DataDb) - attachmentDao := daos.NewAttachmentDao(dbManager.DataDb) - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - - assetDbParams := &database.DatabaseParams{ - Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}, - IncludeRelations: []string{attachmentDao.Table()}, - } - - attachmentDbParams := &database.DatabaseParams{ - OrderBy: []string{"created_at asc"}, - Where: squirrel.Eq{attachmentDao.Table() + ".course_id": testData[0].ID}, - } - - // ---------------------------- - // Add 1 asset with 1 attachment - // ---------------------------- - afero.WriteFile(scanner.appFs.Fs, filepath.Join(testData[0].Path, "01 video.mp4"), []byte("video"), os.ModePerm) - afero.WriteFile(scanner.appFs.Fs, filepath.Join(testData[0].Path, "01 info.txt"), []byte("info"), os.ModePerm) - - err := CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assets, err := scanner.assetDao.List(assetDbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 1) - require.Equal(t, "video", assets[0].Title) - require.Equal(t, testData[0].ID, assets[0].CourseID) - require.Equal(t, 1, int(assets[0].Prefix.Int16)) - require.Equal(t, filepath.Join(testData[0].Path, "01 video.mp4"), assets[0].Path) - require.True(t, assets[0].Type.IsVideo()) - require.Equal(t, "e56ca866bff1691433766c60304a96583c1a410e53b33ef7d89cb29eac2a97ab", assets[0].Hash) - - attachments, err := scanner.attachmentDao.List(attachmentDbParams, nil) - require.Nil(t, err) - require.Len(t, attachments, 1) - require.Equal(t, "info.txt", attachments[0].Title) - require.Equal(t, filepath.Join(testData[0].Path, "01 info.txt"), attachments[0].Path) - - // ---------------------------- - // Add another attachment - // ---------------------------- - afero.WriteFile(scanner.appFs.Fs, filepath.Join(testData[0].Path, "01 code.zip"), []byte("code"), os.ModePerm) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - // Assert the asset - assets, err = scanner.assetDao.List(assetDbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 1) - require.Len(t, assets[0].Attachments, 2) - - attachments, err = scanner.attachmentDao.List(attachmentDbParams, nil) - require.Nil(t, err) - require.Len(t, attachments, 2) - require.Equal(t, "info.txt", attachments[0].Title) - require.Equal(t, filepath.Join(testData[0].Path, "01 info.txt"), attachments[0].Path) - require.Equal(t, "code.zip", attachments[1].Title) - require.Equal(t, filepath.Join(testData[0].Path, "01 code.zip"), attachments[1].Path) - - // ---------------------------- - // Delete first attachment - // ---------------------------- - scanner.appFs.Fs.Remove(filepath.Join(testData[0].Path, "01 info.txt")) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - // Assert the asset - assets, err = scanner.assetDao.List(assetDbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 1) - require.Equal(t, "video", assets[0].Title) - require.Len(t, assets[0].Attachments, 1) - - attachments, err = scanner.attachmentDao.List(attachmentDbParams, nil) - require.Nil(t, err) - require.Len(t, attachments, 1) - require.Equal(t, "code.zip", attachments[0].Title) - require.Equal(t, filepath.Join(testData[0].Path, "01 code.zip"), attachments[0].Path) - }) - - t.Run("attachments error", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - attachmentDao := daos.NewAttachmentDao(dbManager.DataDb) - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - scanner.appFs.Fs.Create(filepath.Join(testData[0].Path, "01 video.mkv")) - scanner.appFs.Fs.Create(filepath.Join(testData[0].Path, "01 info")) - - // Drop the attachments table - _, err := dbManager.DataDb.Exec("DROP TABLE IF EXISTS " + attachmentDao.Table()) - require.Nil(t, err) - - err = CourseProcessor(scanner, testData[0].Scan) - require.ErrorContains(t, err, "no such table: "+attachmentDao.Table()) - }) - - t.Run("asset priority", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - // ---------------------------- - // Priority is VIDEO -> HTML -> PDF - // ---------------------------- - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - assetDao := daos.NewAssetDao(dbManager.DataDb) - attachmentDao := daos.NewAttachmentDao(dbManager.DataDb) - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - - assetDbParams := &database.DatabaseParams{ - Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}, - IncludeRelations: []string{attachmentDao.Table()}, - } - - attachmentDbParams := &database.DatabaseParams{ - OrderBy: []string{"created_at asc"}, - Where: squirrel.Eq{attachmentDao.Table() + ".course_id": testData[0].ID}, - } - - // ---------------------------- - // Add PDF (asset) - // ---------------------------- - pdfFile := filepath.Join(testData[0].Path, "01 doc 1.pdf") - afero.WriteFile(scanner.appFs.Fs, pdfFile, []byte("doc 1"), os.ModePerm) - - err := CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assets, err := scanner.assetDao.List(assetDbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 1) - require.Equal(t, pdfFile, assets[0].Path) - require.True(t, assets[0].Type.IsPDF()) - require.Equal(t, "61363a1cb5bf5514e3f9e983b6a96aeb12dd1ccff1b19938231d6b798d5832f9", assets[0].Hash) - require.Empty(t, assets[0].Attachments) - - // ---------------------------- - // Add HTML (asset) - // ---------------------------- - htmlFile := filepath.Join(testData[0].Path, "01 example.html") - afero.WriteFile(scanner.appFs.Fs, htmlFile, []byte("example"), os.ModePerm) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assets, err = scanner.assetDao.List(assetDbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 1) - require.Equal(t, htmlFile, assets[0].Path) - require.True(t, assets[0].Type.IsHTML()) - require.Equal(t, "43aeba97fea3bfc61a897ca37b73e79c74b2ff6ea792446764a1daf65784c971", assets[0].Hash) - require.Len(t, assets[0].Attachments, 1) - - attachments, err := scanner.attachmentDao.List(attachmentDbParams, nil) - require.Nil(t, err) - require.Len(t, attachments, 1) - require.Equal(t, pdfFile, attachments[0].Path) - - // ---------------------------- - // Add VIDEO (asset) - // ---------------------------- - videoFile := filepath.Join(testData[0].Path, "01 video.mp4") - afero.WriteFile(scanner.appFs.Fs, videoFile, []byte("video"), os.ModePerm) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assets, err = scanner.assetDao.List(assetDbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 1) - require.Equal(t, videoFile, assets[0].Path) - require.True(t, assets[0].Type.IsVideo()) - require.Equal(t, "e56ca866bff1691433766c60304a96583c1a410e53b33ef7d89cb29eac2a97ab", assets[0].Hash) - require.Len(t, assets[0].Attachments, 2) - - attachments, err = scanner.attachmentDao.List(attachmentDbParams, nil) - require.Nil(t, err) - require.Len(t, attachments, 2) - require.Equal(t, pdfFile, attachments[0].Path) - require.Equal(t, htmlFile, attachments[1].Path) - - // ---------------------------- - // Add PDF file (attachment) - // ---------------------------- - pdfFile2 := filepath.Join(testData[0].Path, "/01 - e.pdf") - afero.WriteFile(scanner.appFs.Fs, pdfFile2, []byte("e"), os.ModePerm) - - err = CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - assets, err = scanner.assetDao.List(assetDbParams, nil) - require.Nil(t, err) - require.Len(t, assets, 1) - require.Equal(t, videoFile, assets[0].Path) - require.True(t, assets[0].Type.IsVideo()) - require.Equal(t, "e56ca866bff1691433766c60304a96583c1a410e53b33ef7d89cb29eac2a97ab", assets[0].Hash) - require.Len(t, assets[0].Attachments, 3) - - attachments, err = scanner.attachmentDao.List(attachmentDbParams, nil) - require.Nil(t, err) - require.Len(t, attachments, 3) - require.Equal(t, pdfFile, attachments[0].Path) - require.Equal(t, htmlFile, attachments[1].Path) - require.Equal(t, pdfFile2, attachments[2].Path) - }) - - t.Run("course updated", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Scan().Build() - - scanner.appFs.Fs.Mkdir(testData[0].Path, os.ModePerm) - - err := CourseProcessor(scanner, testData[0].Scan) - require.Nil(t, err) - - updatedCourse, err := scanner.courseDao.Get(testData[0].ID, nil, nil) - require.Nil(t, err) - require.NotEqual(t, testData[0].UpdatedAt, updatedCourse.UpdatedAt) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseScanner_BuildFileInfo(t *testing.T) { - t.Run("invalid", func(t *testing.T) { - var tests = []string{ - // No prefix - "file", - "file.file", - "file.avi", - " - file.avi", - "- - file.avi", - ".avi", - // Invalid prefix - "-1 - file.avi", - "a - file.avi", - "1.1 - file.avi", - "2.3-file.avi", - "1file.avi", - } - - for _, tt := range tests { - fb := parseFileName(tt) - require.Nil(t, fb) - } - }) - - t.Run("assets", func(t *testing.T) { - var tests = []struct { - in string - expected *parsedFileName - }{ - // Video (with varied filenames) - {"0 file 0.avi", &parsedFileName{prefix: 0, title: "file 0", ext: "avi", attachmentTitle: "file 0.avi", asset: types.NewAsset("avi")}}, - {"001 file 1.mp4", &parsedFileName{prefix: 1, title: "file 1", ext: "mp4", attachmentTitle: "file 1.mp4", asset: types.NewAsset("mp4")}}, - {"1-file.ogg", &parsedFileName{prefix: 1, title: "file", ext: "ogg", attachmentTitle: "file.ogg", asset: types.NewAsset("ogg")}}, - {"2 - file.webm", &parsedFileName{prefix: 2, title: "file", ext: "webm", attachmentTitle: "file.webm", asset: types.NewAsset("webm")}}, - {"3 -file.m4a", &parsedFileName{prefix: 3, title: "file", ext: "m4a", attachmentTitle: "file.m4a", asset: types.NewAsset("m4a")}}, - {"4- file.opus", &parsedFileName{prefix: 4, title: "file", ext: "opus", attachmentTitle: "file.opus", asset: types.NewAsset("opus")}}, - {"5000 --- file.wav", &parsedFileName{prefix: 5000, title: "file", ext: "wav", attachmentTitle: "file.wav", asset: types.NewAsset("wav")}}, - {"0100 file.mp3", &parsedFileName{prefix: 100, title: "file", ext: "mp3", attachmentTitle: "file.mp3", asset: types.NewAsset("mp3")}}, - // PDF - {"1 - doc.pdf", &parsedFileName{prefix: 1, title: "doc", ext: "pdf", attachmentTitle: "doc.pdf", asset: types.NewAsset("pdf")}}, - // HTML - {"1 index.html", &parsedFileName{prefix: 1, title: "index", ext: "html", attachmentTitle: "index.html", asset: types.NewAsset("html")}}, - } - - for _, tt := range tests { - fb := parseFileName(tt.in) - require.Equal(t, tt.expected, fb, fmt.Sprintf("error for [%s]", tt.in)) - } - }) - - t.Run("attachments", func(t *testing.T) { - var tests = []struct { - in string - expected *parsedFileName - }{ - // No title - {"01", &parsedFileName{prefix: 1, title: "", attachmentTitle: "01"}}, - {"200.pdf", &parsedFileName{prefix: 200, title: "", attachmentTitle: "200.pdf"}}, - {"1 -.txt", &parsedFileName{prefix: 1, title: "", attachmentTitle: "1 -.txt"}}, - {"1 .txt", &parsedFileName{prefix: 1, title: "", attachmentTitle: "1 .txt"}}, - {"1 .pdf", &parsedFileName{prefix: 1, title: "", attachmentTitle: "1 .pdf"}}, - // No extension (fileName should have no prefix) - {"0 file 0", &parsedFileName{prefix: 0, title: "file 0", attachmentTitle: "file 0"}}, - {"001 file 1", &parsedFileName{prefix: 1, title: "file 1", attachmentTitle: "file 1"}}, - {"1001 - file", &parsedFileName{prefix: 1001, title: "file", attachmentTitle: "file"}}, - {"0123-file", &parsedFileName{prefix: 123, title: "file", attachmentTitle: "file"}}, - {"1 --- file", &parsedFileName{prefix: 1, title: "file", attachmentTitle: "file"}}, - // Non-asset extension (fileName should have no prefix) - {"1 file.txt", &parsedFileName{prefix: 1, title: "file", ext: "txt", attachmentTitle: "file.txt"}}, - } - - for _, tt := range tests { - fb := parseFileName(tt.in) - require.Equal(t, tt.expected, fb, fmt.Sprintf("error for [%s]", tt.in)) - } - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseScanner_IsCard(t *testing.T) { - t.Run("invalid", func(t *testing.T) { - var tests = []string{ - "card", - "1234", - "1234.jpg", - "jpg", - "card.test.jpg", - "card.txt", - } - - for _, tt := range tests { - require.False(t, isCard(tt)) - } - }) - - t.Run("valid", func(t *testing.T) { - var tests = []string{ - "card.jpg", - "card.jpeg", - "card.png", - "card.webp", - "card.tiff", - } - - for _, tt := range tests { - require.True(t, isCard(tt)) - } - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseScanner_UpdateAssets(t *testing.T) { - t.Run("nothing added or deleted", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Assets(10).Build() - assetDao := daos.NewAssetDao(dbManager.DataDb) - - err := updateAssets(scanner.assetDao, nil, testData[0].ID, testData[0].Assets) - require.Nil(t, err) - - dbParams := &database.DatabaseParams{Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}} - count, err := scanner.assetDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 10, count) - }) - - t.Run("add", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Assets(12).Build() - assetDao := daos.NewAssetDao(dbManager.DataDb) - - // Delete the assets (so we can add them again) - for _, a := range testData[0].Assets { - require.Nil(t, scanner.assetDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": a.ID}}, nil)) - } - - dbParams := &database.DatabaseParams{Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}} - - // ---------------------------- - // Add 10 assets - // ---------------------------- - err := updateAssets(scanner.assetDao, nil, testData[0].ID, testData[0].Assets[:10]) - require.Nil(t, err) - - count, err := scanner.assetDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 10, count) - - // ---------------------------- - // Add another 2 assets - // ---------------------------- - testData[0].Assets[10].ID = "" - testData[0].Assets[11].ID = "" - - err = updateAssets(scanner.assetDao, nil, testData[0].ID, testData[0].Assets) - require.Nil(t, err) - - count, err = scanner.assetDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 12, count) - - // Ensure all assets have an ID - for _, a := range testData[0].Assets { - require.NotEmpty(t, a.ID) - } - }) - - t.Run("delete", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Assets(12).Build() - - assetDao := daos.NewAssetDao(dbManager.DataDb) - - dbParams := &database.DatabaseParams{Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}} - - // ---------------------------- - // Remove 2 assets - // ---------------------------- - testData[0].Assets = testData[0].Assets[2:] - - err := updateAssets(scanner.assetDao, nil, testData[0].ID, testData[0].Assets) - require.Nil(t, err) - - count, err := scanner.assetDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 10, count) - - // ---------------------------- - // Remove another 2 assets - // ---------------------------- - testData[0].Assets = testData[0].Assets[2:] - - err = updateAssets(scanner.assetDao, nil, testData[0].ID, testData[0].Assets) - require.Nil(t, err) - - count, err = scanner.assetDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 8, count) - }) - - t.Run("rename", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Assets(12).Build() - assetDao := daos.NewAssetDao(dbManager.DataDb) - - // Delete the assets (so we can add them again) - for _, a := range testData[0].Assets { - require.Nil(t, scanner.assetDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": a.ID}}, nil)) - } - - dbParams := &database.DatabaseParams{Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}} - - // ---------------------------- - // Add assets - // ---------------------------- - err := updateAssets(scanner.assetDao, nil, testData[0].ID, testData[0].Assets) - require.Nil(t, err) - - count, err := scanner.assetDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 12, count) - - // ---------------------------- - // Rename 2 assets - // ---------------------------- - testData[0].Assets[10].Prefix = sql.NullInt16{Int16: int16(rand.Intn(100-1) + 1), Valid: true} - testData[0].Assets[10].Title = security.PseudorandomString(7) - testData[0].Assets[10].Chapter = security.PseudorandomString(7) - testData[0].Assets[10].Path = fmt.Sprintf("%s/%s/%d %s.mp4", testData[0].Path, testData[0].Assets[10].Chapter, testData[0].Assets[10].Prefix.Int16, testData[0].Assets[10].Title) - - testData[0].Assets[11].Prefix = sql.NullInt16{Int16: int16(rand.Intn(100-1) + 1), Valid: true} - testData[0].Assets[11].Title = security.PseudorandomString(7) - testData[0].Assets[11].Chapter = security.PseudorandomString(7) - testData[0].Assets[11].Path = fmt.Sprintf("%s/%s/%d %s.mp4", testData[0].Path, testData[0].Assets[11].Chapter, testData[0].Assets[11].Prefix.Int16, testData[0].Assets[11].Title) - - err = updateAssets(scanner.assetDao, nil, testData[0].ID, testData[0].Assets) - require.Nil(t, err) - - a, err := scanner.assetDao.List(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 12, len(a)) - - // Ensure the assets were updated - // - // Note: Order is not guaranteed, so we have to validate this way - for _, asset := range a { - if asset.ID == testData[0].Assets[10].ID { - require.Equal(t, testData[0].Assets[10].Title, asset.Title) - require.Equal(t, testData[0].Assets[10].Prefix, asset.Prefix) - require.Equal(t, testData[0].Assets[10].Chapter, asset.Chapter) - require.Equal(t, testData[0].Assets[10].Path, asset.Path) - } else if asset.ID == testData[0].Assets[11].ID { - require.Equal(t, testData[0].Assets[11].Title, asset.Title) - require.Equal(t, testData[0].Assets[11].Prefix, asset.Prefix) - require.Equal(t, testData[0].Assets[11].Chapter, asset.Chapter) - require.Equal(t, testData[0].Assets[11].Path, asset.Path) - } - } - }) - - t.Run("swap", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Assets(2).Build() - - assetDao := daos.NewAssetDao(dbManager.DataDb) - dbParams := &database.DatabaseParams{Where: squirrel.Eq{assetDao.Table() + ".course_id": testData[0].ID}} - - // ---------------------------- - // Swap assets title and path - // ---------------------------- - testData[0].Assets[0].Title, testData[0].Assets[1].Title = testData[0].Assets[1].Title, testData[0].Assets[0].Title - testData[0].Assets[0].Path, testData[0].Assets[1].Path = testData[0].Assets[1].Path, testData[0].Assets[0].Path - - err := updateAssets(scanner.assetDao, nil, testData[0].ID, testData[0].Assets) - require.Nil(t, err) - - a, err := scanner.assetDao.List(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 2, len(a)) - - // Ensure the assets were updated - // - // Note: Order is not guaranteed, so we have to validate this way - for _, asset := range a { - if asset.ID == testData[0].Assets[0].ID { - require.Equal(t, testData[0].Assets[0].Title, asset.Title) - require.Equal(t, testData[0].Assets[0].Path, asset.Path) - } else if asset.ID == testData[0].Assets[1].ID { - require.Equal(t, testData[0].Assets[1].Title, asset.Title) - require.Equal(t, testData[0].Assets[1].Path, asset.Path) - } - } - }) - - t.Run("db error", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - assetDao := daos.NewAssetDao(dbManager.DataDb) - - // Drop the table - _, err := dbManager.DataDb.Exec("DROP TABLE IF EXISTS " + assetDao.Table()) - require.Nil(t, err) - - err = updateAssets(scanner.assetDao, nil, "1234", []*models.Asset{}) - require.ErrorContains(t, err, fmt.Sprintf("no such table: %s", assetDao.Table())) - }) -} - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -func TestCourseScanner_UpdateAttachments(t *testing.T) { - t.Run("nothing added or delete)", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Assets(1).Attachments(10).Build() - attDao := daos.NewAttachmentDao(dbManager.DataDb) - - err := updateAttachments(scanner.attachmentDao, nil, testData[0].ID, testData[0].Assets[0].Attachments) - require.Nil(t, err) - - count, err := scanner.attachmentDao.Count(&database.DatabaseParams{Where: squirrel.Eq{attDao.Table() + ".course_id": testData[0].ID}}, nil) - require.Nil(t, err) - require.Equal(t, 10, count) - }) - - t.Run("add", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Assets(1).Attachments(12).Build() - attDao := daos.NewAttachmentDao(dbManager.DataDb) - - // Delete the attachments (so we can add them again) - for _, a := range testData[0].Assets[0].Attachments { - require.Nil(t, scanner.attachmentDao.Delete(&database.DatabaseParams{Where: squirrel.Eq{"id": a.ID}}, nil)) - } - - dbParams := &database.DatabaseParams{Where: squirrel.Eq{attDao.Table() + ".course_id": testData[0].ID}} - - // ---------------------------- - // Add 10 attachments - // ---------------------------- - err := updateAttachments(scanner.attachmentDao, nil, testData[0].ID, testData[0].Assets[0].Attachments[:10]) - require.Nil(t, err) - - count, err := scanner.attachmentDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 10, count) - - // ---------------------------- - // Add another 2 attachments - // ---------------------------- - err = updateAttachments(scanner.attachmentDao, nil, testData[0].ID, testData[0].Assets[0].Attachments) - require.Nil(t, err) - - count, err = scanner.attachmentDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 12, count) - }) - - t.Run("delete", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - testData := daos.NewTestBuilder(t).Db(dbManager.DataDb).Courses(1).Assets(1).Attachments(12).Build() - attachmentDao := daos.NewAttachmentDao(dbManager.DataDb) - - dbParams := &database.DatabaseParams{Where: squirrel.Eq{attachmentDao.Table() + ".course_id": testData[0].ID}} - - // ---------------------------- - // Remove 2 attachments - // ---------------------------- - testData[0].Assets[0].Attachments = testData[0].Assets[0].Attachments[2:] - - err := updateAttachments(scanner.attachmentDao, nil, testData[0].ID, testData[0].Assets[0].Attachments) - require.Nil(t, err) - - count, err := scanner.attachmentDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 10, count) - - // ---------------------------- - // Remove another 2 attachments - // ---------------------------- - testData[0].Assets[0].Attachments = testData[0].Assets[0].Attachments[2:] - - err = updateAttachments(scanner.attachmentDao, nil, testData[0].ID, testData[0].Assets[0].Attachments) - require.Nil(t, err) - - count, err = scanner.attachmentDao.Count(dbParams, nil) - require.Nil(t, err) - require.Equal(t, 8, count) - }) - - t.Run("db error", func(t *testing.T) { - scanner, dbManager, _, _ := setupCourseScanner(t) - - attachmentDao := daos.NewAttachmentDao(dbManager.DataDb) - - // Drop the table - _, err := dbManager.DataDb.Exec("DROP TABLE IF EXISTS " + attachmentDao.Table()) - require.Nil(t, err) - - err = updateAttachments(scanner.attachmentDao, nil, "1234", []*models.Attachment{}) - require.ErrorContains(t, err, fmt.Sprintf("no such table: %s", attachmentDao.Table())) - }) -} diff --git a/utils/logger/batch.go b/utils/logger/batch.go index 40b16d5..eff4d96 100644 --- a/utils/logger/batch.go +++ b/utils/logger/batch.go @@ -190,8 +190,13 @@ func (h *BatchHandler) Handle(ctx context.Context, r slog.Record) error { return true }) + t, err := types.ParseDateTime(r.Time) + if err != nil { + return err + } + log := &Log{ - Time: r.Time, + Time: t, Level: r.Level, Message: r.Message, Data: types.JsonMap(data), diff --git a/utils/logger/logger.go b/utils/logger/logger.go index a50b53a..911cca1 100644 --- a/utils/logger/logger.go +++ b/utils/logger/logger.go @@ -13,7 +13,7 @@ import ( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ type Log struct { - Time time.Time + Time types.DateTime Message string Level slog.Level Data types.JsonMap diff --git a/utils/pagination/base.go b/utils/pagination/base.go index 59e5370..17e8f8b 100644 --- a/utils/pagination/base.go +++ b/utils/pagination/base.go @@ -121,7 +121,6 @@ func (p *Pagination) Apply(queryBuilder sq.SelectBuilder) sq.SelectBuilder { // BuildResult builds a result object from the pagination values, which is suitable for a HTTP // response func (p *Pagination) BuildResult(m any) (*PaginationResult, error) { - // Slice to hold the marshaled items items := []json.RawMessage{} diff --git a/utils/pagination/base_test.go b/utils/pagination/base_test.go index d5812f8..a1f445d 100644 --- a/utils/pagination/base_test.go +++ b/utils/pagination/base_test.go @@ -4,9 +4,9 @@ import ( "encoding/json" "fmt" "testing" - "time" sq "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/utils/types" "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" @@ -230,25 +230,25 @@ func Test_BuildResult(t *testing.T) { p.SetCount(24) type Data struct { - ID string `json:"id"` - CreatedAt time.Time `json:"createdAt"` + ID string `json:"id"` + CreatedAt types.DateTime `json:"createdAt"` } // The data to marshal data := []Data{ - {ID: "1", CreatedAt: time.Now()}, - {ID: "2", CreatedAt: time.Now()}, + {ID: "1", CreatedAt: types.NowDateTime()}, + {ID: "2", CreatedAt: types.NowDateTime()}, } result, err := p.BuildResult(data) - require.Nil(t, err) + require.NoError(t, err) require.Len(t, result.Items, 2) for i, raw := range result.Items { var d Data require.Nil(t, json.Unmarshal(raw, &d)) require.Equal(t, data[i].ID, d.ID) - require.True(t, d.CreatedAt.Equal(data[i].CreatedAt)) + require.Equal(t, data[i].CreatedAt.String(), d.CreatedAt.String()) } }) @@ -301,7 +301,7 @@ func Test_Apply(t *testing.T) { builder = p.Apply(builder) query, args, err := builder.ToSql() - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, "SELECT * FROM dummy LIMIT 10 OFFSET 0", query) require.Nil(t, args) } diff --git a/utils/schema/field.go b/utils/schema/field.go new file mode 100644 index 0000000..ebcaa6b --- /dev/null +++ b/utils/schema/field.go @@ -0,0 +1,152 @@ +package schema + +import ( + "reflect" + + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +type field struct { + // The name of the struct field + Name string + + // The position of the field in the struct + Position []int + + // The name of the column in the database + Column string + + // A db column alias + Alias string + + // When true, the field cannot be null in the database + NotNull bool + + // When true, the field can be updated + Mutable bool + + // When true, the field will be skipped during create if it is null + IgnoreIfNull bool + + // The table the field belongs to. When empty it belongs to the main table + JoinTable string + + // ReflectValueOf is a callback that takes a struct, as a reflect.Value, and returns the + // reflect.Value of the field + ReflectValueOf func(reflect.Value) reflect.Value + + // ValueOf is a callback that takes a struct, as a reflect.Value, and returns the actual + // value of the field and whether it is zero + ValueOf func(reflect.Value) (any, bool) + + // The concrete type of the field + concreteType reflect.Type +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// parseField will parse a field from the struct and return a field struct +func parseField(sf reflect.StructField, config *modelFieldConfig) *field { + f := &field{ + Name: sf.Name, + Position: sf.Index, + Column: config.column, + Alias: config.alias, + NotNull: config.notNull, + Mutable: config.mutable, + IgnoreIfNull: config.ignoreIfNull, + JoinTable: config.joinTable, + concreteType: sf.Type, + } + + if f.Column == "" { + f.Column = utils.SnakeCase(sf.Name) + } + + // Drill down to the concrete type + for f.concreteType.Kind() == reflect.Ptr { + f.concreteType = f.concreteType.Elem() + } + + f.setReflectValueOf() + f.setValueOf() + + return f +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// parseEmbeddedField will validate an embedded field implements the Modeler interface and then +// parses the fields of the embedded struct +func parseEmbeddedField(sf reflect.StructField) ([]*field, error) { + if sf.Type.Kind() != reflect.Struct { + return nil, utils.ErrEmbedded + } + + m, isDefiner := reflect.New(sf.Type).Interface().(Definer) + if isDefiner { + config := &ModelConfig{} + m.Define(config) + + rt := reflect.Indirect(reflect.ValueOf(m)).Type() + + fields := []*field{} + for i := range rt.NumField() { + sf := rt.Field(i) + + for sf.Type.Kind() == reflect.Ptr { + sf.Type = sf.Type.Elem() + } + + if fieldConfig, ok := config.fields[sf.Name]; ok { + fields = append(fields, parseField(sf, fieldConfig)) + } else if _, ok := config.embedded[sf.Name]; ok { + fs, err := parseEmbeddedField(sf) + if err != nil { + return nil, err + } + + fields = append(fields, fs...) + } + } + + for _, field := range fields { + field.Position = append(sf.Index, field.Position...) + } + + return fields, nil + } else { + return nil, utils.ErrEmbedded + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// SetReflectValueOf sets the ReflectValueOf callback for the field +func (f *field) setReflectValueOf() { + f.ReflectValueOf = func(rv reflect.Value) reflect.Value { + value := reflect.Indirect(rv) + if len(f.Position) == 1 { + value = value.Field(f.Position[0]) + } else { + for _, p := range f.Position { + value = value.Field(p) + } + } + + return value + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// SetValueOf sets the ValueOf callback for the field. The callback will return the actual value +// of the field from the given struct and if the value is zero +func (f *field) setValueOf() { + f.ValueOf = func(rv reflect.Value) (any, bool) { + value := f.ReflectValueOf(rv) + return value.Interface(), value.IsZero() + } +} diff --git a/utils/schema/helper.go b/utils/schema/helper.go new file mode 100644 index 0000000..a24131e --- /dev/null +++ b/utils/schema/helper.go @@ -0,0 +1,191 @@ +package schema + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// BASE +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// TestBase defines a test base struct +type TestBase struct { + ID int +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Defined implements the `Definer` interface +func (b *TestBase) Define(c *ModelConfig) { + c.Field("ID") +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// POST +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// TestPost defines a test post struct +type TestPost struct { + TestBase + UserID int + Title string + Content string +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `Modeler` interface +func (p *TestPost) Table() string { + return "posts" + +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Define implements the `Modeler` interface +func (p *TestPost) Define(c *ModelConfig) { + c.Embedded("TestBase") + + c.Field("UserID").NotNull() + c.Field("Title").NotNull().Mutable() + c.Field("Content").NotNull().Mutable() +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// PROFILE +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// TestProfile defines a test profile struct +type TestProfile struct { + TestBase + UserID int + Name string + Username string + Email string +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `Modeler` interface +func (p *TestProfile) Table() string { + return "profiles" +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Define implements the `Modeler` interface +func (p *TestProfile) Define(c *ModelConfig) { + c.Embedded("TestBase") + + c.Field("UserID").NotNull() + c.Field("Name").NotNull().Mutable() + c.Field("Username").NotNull().Mutable() + c.Field("Email").NotNull().Mutable() +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// USER +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +type TestUser struct { + TestBase + Profile TestProfile + Posts []TestPost + PtrPosts *[]*TestPost +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Table implements the `Modeler` interface +func (u *TestUser) Table() string { + return "users" +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Define implements the `Modeler` interface +func (u *TestUser) Define(c *ModelConfig) { + c.Embedded("TestBase") + + c.Relation("Profile").MatchOn("user_id") + c.Relation("Posts").MatchOn("user_id") + c.Relation("PtrPosts").MatchOn("user_id") +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// setup creates a new in-memory SQLite database +func setup(tb testing.TB) *sql.DB { + tb.Helper() + + db, err := sql.Open("sqlite3", ":memory:") + require.NoError(tb, err) + + // Create tables + _, err = db.Exec(` + -- Create the users table + CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT + ); + + -- Create the profiles table + CREATE TABLE profiles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + name TEXT NOT NULL, + username TEXT NOT NULL, + email TEXT NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) + ); + + -- Create the posts table + CREATE TABLE posts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + title TEXT NOT NULL, + content TEXT NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) + ); + `) + require.NoError(tb, err) + + // Schemas + userSchema, err := Parse(&TestUser{}) + require.NoError(tb, err) + + profileSchema, err := Parse(&TestProfile{}) + require.NoError(tb, err) + + postSchema, err := Parse(&TestPost{}) + require.NoError(tb, err) + + // Insert John with a profile and 2 posts + john := &TestUser{TestBase: TestBase{ID: 1}} + _, err = userSchema.Insert(john, db) + require.NoError(tb, err) + + _, err = profileSchema.Insert(&TestProfile{TestBase: TestBase{ID: 1}, UserID: john.ID, Name: "John", Username: "john_doe", Email: "john@test.com"}, db) + require.NoError(tb, err) + + _, err = postSchema.Insert(&TestPost{TestBase: TestBase{ID: 1}, UserID: john.ID, Title: "Post 1 by John", Content: "This is the first post by John."}, db) + require.NoError(tb, err) + + _, err = postSchema.Insert(&TestPost{TestBase: TestBase{ID: 2}, UserID: john.ID, Title: "Post 2 by John", Content: "This is the second post by John."}, db) + require.NoError(tb, err) + + // Insert Jane with a profile and 1 post + jane := &TestUser{TestBase: TestBase{ID: 2}} + _, err = userSchema.Insert(jane, db) + require.NoError(tb, err) + + _, err = profileSchema.Insert(&TestProfile{TestBase: TestBase{ID: 2}, UserID: jane.ID, Name: "Jane", Username: "jane_doe", Email: "jane@test.com"}, db) + require.NoError(tb, err) + + _, err = postSchema.Insert(&TestPost{TestBase: TestBase{ID: 3}, UserID: jane.ID, Title: "Post by Jane", Content: "This is the post by Jane."}, db) + require.NoError(tb, err) + + return db +} diff --git a/utils/schema/modeler.go b/utils/schema/modeler.go new file mode 100644 index 0000000..094907f --- /dev/null +++ b/utils/schema/modeler.go @@ -0,0 +1,195 @@ +package schema + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Modeler defines the interface that each struct (model) should implement in order to be used +type Modeler interface { + Table() string + Definer +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Definer defines the interface that each struct (model) should implement in order to be used +type Definer interface { + Define(*ModelConfig) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// ModelConfig defines the configuration for the model +type ModelConfig struct { + // Embedded fields + embedded map[string]struct{} + + // Common fields + fields map[string]*modelFieldConfig + + // Relations + relations map[string]*modelRelationConfig + + // A slice of left joins + leftJoins []*modelLeftJoinConfig +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Embedded field +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Embedded adds an embedded field +func (m *ModelConfig) Embedded(name string) { + + if m.embedded == nil { + m.embedded = make(map[string]struct{}) + } + + m.embedded[name] = struct{}{} +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Simple field +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// modelFieldConfig defines the configuration for a field in the model +type modelFieldConfig struct { + // The name of the struct field + name string + // Override for the db column name + column string + // When true, the field cannot be null in the database + notNull bool + // A db column alias + alias string + // When true, the field can be updated + mutable bool + // When true, the field will be skipped during create if it is null + ignoreIfNull bool + // The table the field belongs to. When empty it belongs to the main table + joinTable string +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Field adds a common field to the fields map +func (m *ModelConfig) Field(name string) *modelFieldConfig { + field := &modelFieldConfig{name: name} + + if m.fields == nil { + m.fields = make(map[string]*modelFieldConfig) + } + + m.fields[name] = field + return field +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Column sets the name of the column in the database +func (f *modelFieldConfig) Column(name string) *modelFieldConfig { + f.column = name + return f +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// NotNull signals that the field cannot be null in the database +func (f *modelFieldConfig) NotNull() *modelFieldConfig { + f.notNull = true + return f +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Alias sets the alias for the column +func (f *modelFieldConfig) Alias(name string) *modelFieldConfig { + f.alias = name + return f +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Mutable signals that the field can be updated +func (f *modelFieldConfig) Mutable() *modelFieldConfig { + f.mutable = true + return f +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// IgnoreIfNull signals that the field should be skipped during create if it is null +func (f *modelFieldConfig) IgnoreIfNull() *modelFieldConfig { + f.ignoreIfNull = true + return f +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// JoinTable sets the table the field belongs to +func (f *modelFieldConfig) JoinTable(table string) *modelFieldConfig { + f.joinTable = table + return f +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Relation field +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// modelFieldConfig defines the configuration for a field in the model +type modelRelationConfig struct { + // The name of the struct field + name string + // The column on the relation table to match with + match string +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Field adds a relation field to the model +func (m *ModelConfig) Relation(name string) *modelRelationConfig { + relation := &modelRelationConfig{name: name} + + if m.relations == nil { + m.relations = make(map[string]*modelRelationConfig) + } + + m.relations[name] = relation + return relation +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Alias sets the alias for the column +func (r *modelRelationConfig) MatchOn(name string) *modelRelationConfig { + r.match = name + return r +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Left Join +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// modelLeftJoinConfig defines the configuration for a left join +type modelLeftJoinConfig struct { + // The name of the table to join with + table string + // The condition for the join, e.g. "table1.id = table2.id" + on string +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// LeftJoin adds a left join to the model +func (m *ModelConfig) LeftJoin(table string) *modelLeftJoinConfig { + join := &modelLeftJoinConfig{table: table} + m.leftJoins = append(m.leftJoins, join) + + return join +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// On sets the condition for the join +func (j *modelLeftJoinConfig) On(condition string) *modelLeftJoinConfig { + j.on = condition + return j +} diff --git a/utils/schema/relation.go b/utils/schema/relation.go new file mode 100644 index 0000000..4ffe6d2 --- /dev/null +++ b/utils/schema/relation.go @@ -0,0 +1,52 @@ +package schema + +import ( + "reflect" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +type relation struct { + // The name of the struct field + Name string + + // The position of the field in the struct + Position []int + + // True when the relation is a has many (i.e. a slice) + HasMany bool + + // The field to match on in the relation + MatchOn string + + // The type of the related struct + RelatedType reflect.Type + + // When true, the relation is a pointer + IsPtr bool +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// parseField will parse a field from the struct and return a field struct +func parseRelation(sf reflect.StructField, config *modelRelationConfig) *relation { + r := &relation{ + Name: sf.Name, + Position: sf.Index, + MatchOn: config.match, + RelatedType: sf.Type, + IsPtr: sf.Type.Kind() == reflect.Ptr, + } + + concreteType := sf.Type + for concreteType.Kind() == reflect.Ptr { + concreteType = concreteType.Elem() + } + + if concreteType.Kind() == reflect.Slice { + r.HasMany = true + r.RelatedType = concreteType.Elem() + } + + return r +} diff --git a/utils/schema/schema.go b/utils/schema/schema.go new file mode 100644 index 0000000..6a07fcd --- /dev/null +++ b/utils/schema/schema.go @@ -0,0 +1,805 @@ +package schema + +import ( + "database/sql" + "fmt" + "reflect" + "sync" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/utils" +) + +var cache = &sync.Map{} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Schema defines the structure of a model after parsing +type Schema struct { + // The name of the table in the database + Table string + + // The fields of the model + Fields []*field + + // FieldsByColumn is a map of fields by their DB column name + FieldsByColumn map[string]*field + + // A slice of relations + Relations []*relation + + // A slice of left joins + LeftJoins []string +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Parse parses a model +func Parse(model any) (*Schema, error) { + if model == nil { + return nil, utils.ErrNilPtr + } + + // Get the reflect value and unwrap pointers + rv, err := concreteReflectValue(reflect.ValueOf(model)) + if err != nil { + return nil, err + } + + rt := rv.Type() + + // If the model is a pointer, slice, or array, get the element type + for rt.Kind() == reflect.Slice || rt.Kind() == reflect.Array || rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + + // Attempt to load the Schema from cache + if v, ok := cache.Load(rt); ok { + s := v.(*Schema) + return s, nil + } + + // Error when the model does not implement the Modeler interface + modeler, isModeler := reflect.New(rt).Interface().(Modeler) + if !isModeler { + return nil, utils.ErrNotModeler + } + + s := &Schema{ + Table: modeler.Table(), + } + + config := &ModelConfig{} + modeler.Define(config) + + for i := range rt.NumField() { + sf := rt.Field(i) + + if fieldConfig, ok := config.fields[sf.Name]; ok { + s.Fields = append(s.Fields, parseField(sf, fieldConfig)) + } else if _, ok := config.embedded[sf.Name]; ok { + fields, err := parseEmbeddedField(sf) + if err != nil { + return nil, err + } + + s.Fields = append(s.Fields, fields...) + } else if relationConfig, ok := config.relations[sf.Name]; ok { + s.Relations = append(s.Relations, parseRelation(sf, relationConfig)) + } + } + + // Build the FieldsByColumn map + s.FieldsByColumn = make(map[string]*field, len(s.Fields)) + for _, f := range s.Fields { + if f.Alias != "" { + s.FieldsByColumn[f.Alias] = f + } else { + s.FieldsByColumn[f.Column] = f + } + } + + // Build the left joins + for _, join := range config.leftJoins { + s.LeftJoins = append(s.LeftJoins, fmt.Sprintf("%s ON %s", join.table, join.on)) + } + + // Store the schema in the cache + if v, loaded := cache.LoadOrStore(rt, s); loaded { + s := v.(*Schema) + return s, nil + } + + return s, nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// CALLERS +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Count calls the CountBuilder and executes the query, returning the count +func (s *Schema) Count(options *database.Options, db database.Querier) (int, error) { + query, args, _ := s.CountBuilder(options).ToSql() + + var count int + err := db.QueryRow(query, args...).Scan(&count) + return count, err +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Insert calls the InsertBuilder and executes the query, inserting a row +func (s *Schema) Insert(model any, db database.Querier) (sql.Result, error) { + query, args, _ := s.InsertBuilder(model).ToSql() + return db.Exec(query, args...) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Select calls the SelectBuilder and executes the query, scanning the result into the model, +// which may be a struct or a slice of structs +func (s *Schema) Select(model any, options *database.Options, db database.Querier) error { + rv := reflect.ValueOf(model) + + if rv.Kind() != reflect.Ptr { + return utils.ErrNotPtr + } + + if rv.IsNil() { + return utils.ErrNilPtr + } + + var err error + var rows Rows + + concreteRv, err := concreteReflectValue(reflect.ValueOf(model)) + if err != nil { + return err + } + + if concreteRv.Kind() == reflect.Slice { + query, args, _ := s.SelectBuilder(options).ToSql() + + rows, err = db.Query(query, args...) + if err != nil { + return err + } + defer rows.Close() + + err = s.ScanMany(rows, rv) + if err != nil { + return err + } + + err = s.loadRelationsMany(concreteRv, db) + } else { + query, args, _ := s.SelectBuilder(options).Limit(1).ToSql() + rows, err = db.Query(query, args...) + if err != nil { + return err + } + defer rows.Close() + + err = s.ScanOne(rows, rv) + if err != nil { + return err + } + + err = s.loadRelationsOne(concreteRv, db) + } + + return err +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Insert calls the InsertBuilder and executes the query, inserting a row +func (s *Schema) Update(model any, options *database.Options, db database.Querier) (sql.Result, error) { + builder, err := s.UpdateBuilder(model, options) + if err != nil { + return nil, err + } + + query, args, _ := builder.ToSql() + return db.Exec(query, args...) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Delete calls the DeleteBuilder and executes the query, deleting rows +func (s *Schema) Delete(options *database.Options, db database.Querier) (sql.Result, error) { + query, args, _ := s.DeleteBuilder(options).ToSql() + return db.Exec(query, args...) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// BUILDERS +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// CountBuilder creates a squirrel SelectBuilder for the model +func (s *Schema) CountBuilder(options *database.Options) squirrel.SelectBuilder { + builder := squirrel. + StatementBuilder. + PlaceholderFormat(squirrel.Question). + Select("COUNT(DISTINCT " + s.Table + ".id)"). + From(s.Table) + + if options != nil && options.Where != nil { + builder = builder.Where(options.Where) + } + + return builder +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// InsertBuilder creates a squirrel InsertBuilder for the model +func (s *Schema) InsertBuilder(model any) squirrel.InsertBuilder { + data := make(map[string]any, len(s.Fields)) + + for _, f := range s.Fields { + // Ignore fields that part of a join + if f.JoinTable != "" { + continue + } + + val, zero := f.ValueOf(reflect.ValueOf(model)) + + // When the field cannot be null and the value is zero, set the value to nil + if f.NotNull && zero { + if f.IgnoreIfNull { + continue + } + + data[f.Column] = nil + } else { + data[f.Column] = val + } + } + + builder := squirrel. + StatementBuilder. + Insert(s.Table). + SetMap(data) + + return builder +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// SelectBuilder creates a squirrel SelectBuilder for the model +func (s *Schema) SelectBuilder(options *database.Options) squirrel.SelectBuilder { + builder := squirrel. + StatementBuilder. + PlaceholderFormat(squirrel.Question). + Select(""). + From(s.Table). + RemoveColumns() + + for _, f := range s.Fields { + table := s.Table + if f.JoinTable != "" { + table = f.JoinTable + } + + if f.Alias != "" { + builder = builder.Column(fmt.Sprintf("%s.%s AS %s", table, f.Column, f.Alias)) + } else { + builder = builder.Column(fmt.Sprintf("%s.%s", table, f.Column)) + } + } + + for _, join := range s.LeftJoins { + builder = builder.LeftJoin(join) + } + + if options != nil { + builder = builder.Where(options.Where). + OrderBy(options.OrderBy...). + GroupBy(options.GroupBy...) + + if options.Pagination != nil { + builder = builder. + Offset(uint64(options.Pagination.Offset())). + Limit(uint64(options.Pagination.Limit())) + } + } + + return builder +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UpdateBuilder creates a squirrel UpdateBuilder for the model +func (s *Schema) UpdateBuilder(model any, options *database.Options) (squirrel.UpdateBuilder, error) { + if options == nil || options.Where == nil { + return squirrel.UpdateBuilder{}, utils.ErrInvalidWhere + } + + builder := squirrel. + StatementBuilder. + Update(s.Table). + Where(options.Where) + + for _, f := range s.Fields { + if f.JoinTable != "" || !f.Mutable { + continue + } + + val, zero := f.ValueOf(reflect.ValueOf(model)) + + if f.NotNull && zero { + if f.IgnoreIfNull { + continue + } + + val = nil + } + + builder = builder.Set(f.Column, val) + } + + return builder, nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// DeleteBuilder creates a squirrel DeleteBuilder for the model +func (s *Schema) DeleteBuilder(options *database.Options) squirrel.DeleteBuilder { + builder := squirrel. + StatementBuilder. + PlaceholderFormat(squirrel.Question). + Delete(s.Table) + + if options != nil && options.Where != nil { + builder = builder.Where(options.Where) + } + + return builder +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// SCANS +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Rows defines the interface for rows +type Rows interface { + Columns() ([]string, error) + Next() bool + Scan(dest ...interface{}) error + Err() error + Close() error +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Scan scans the rows into the model, which can be a pointer to a slice or a single struct +func (s *Schema) Scan(rows Rows, model any) error { + rv := reflect.ValueOf(model) + defer rows.Close() + + if rv.Kind() != reflect.Ptr { + return utils.ErrNotPtr + } + + if rv.IsNil() { + return utils.ErrNilPtr + } + + concreteValue := reflect.Indirect(rv) + + if concreteValue.Kind() == reflect.Slice { + concreteValue.SetLen(0) + + isPtr := concreteValue.Type().Elem().Kind() == reflect.Ptr + + base := concreteValue.Type().Elem() + if isPtr { + base = base.Elem() + } + + columns, err := rows.Columns() + if err != nil { + return err + } + + // Create a pointer of values for each field + values := make([]interface{}, len(columns)) + + for rows.Next() { + instance := reflect.New(base) + concreteInstance := reflect.Indirect(instance) + + for idx, column := range columns { + if field := s.FieldsByColumn[column]; field != nil { + v := concreteInstance + for _, pos := range field.Position { + // TODO - If value is a pointer and nil, initialize it + // TODO - If value is a map and nil, initialize it + v = reflect.Indirect(v).Field(pos) + } + + values[idx] = v.Addr().Interface() + } else { + return fmt.Errorf("column %s not found in model", column) + } + } + + err = rows.Scan(values...) + if err != nil { + return err + } + + if isPtr { + concreteValue.Set(reflect.Append(concreteValue, instance)) + } else { + concreteValue.Set(reflect.Append(concreteValue, concreteInstance)) + } + } + } else { + columns, err := rows.Columns() + if err != nil { + return err + } + + // Create a pointer of values for each field + values := make([]interface{}, len(columns)) + + for idx, column := range columns { + if field := s.FieldsByColumn[column]; field != nil { + v := rv + for _, pos := range field.Position { + // TODO - If value is a pointer and nil, initialize it + // TODO - If value is a map and nil, initialize it + v = reflect.Indirect(v).Field(pos) + } + + values[idx] = v.Addr().Interface() + } else { + return fmt.Errorf("column %s not found in model", column) + } + } + + if !rows.Next() { + return sql.ErrNoRows + } + + err = rows.Scan(values...) + if err != nil { + return err + } + } + + return nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func (s *Schema) ScanMany(rows Rows, rv reflect.Value) error { + defer rows.Close() + + if rv.Kind() != reflect.Ptr { + return utils.ErrNotPtr + } + + if rv.IsNil() { + return utils.ErrNilPtr + } + + concreteRv, err := concreteReflectValue(rv) + if err != nil { + return err + } + + if concreteRv.Kind() != reflect.Slice { + return utils.ErrNotSlice + } + + concreteRv.SetLen(0) + + isPtr := concreteRv.Type().Elem().Kind() == reflect.Ptr + + base := concreteRv.Type().Elem() + if isPtr { + base = base.Elem() + } + + columns, err := rows.Columns() + if err != nil { + return err + } + + // Create a pointer of values for each field + values := make([]interface{}, len(columns)) + + for rows.Next() { + instance := reflect.New(base) + concreteInstance := reflect.Indirect(instance) + + for idx, column := range columns { + if field := s.FieldsByColumn[column]; field != nil { + v := concreteInstance + for _, pos := range field.Position { + // TODO - If value is a pointer and nil, initialize it + // TODO - If value is a map and nil, initialize it + v = reflect.Indirect(v).Field(pos) + } + + values[idx] = v.Addr().Interface() + } else { + return fmt.Errorf("column %s not found in model", column) + } + } + + err = rows.Scan(values...) + if err != nil { + return err + } + + if isPtr { + concreteRv.Set(reflect.Append(concreteRv, instance)) + } else { + concreteRv.Set(reflect.Append(concreteRv, concreteInstance)) + } + } + + return nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// ScanOne scans a single row into the model +func (s *Schema) ScanOne(rows Rows, rv reflect.Value) error { + defer rows.Close() + + if rv.Kind() != reflect.Ptr { + return utils.ErrNotPtr + } + + if rv.IsNil() { + return utils.ErrNilPtr + } + + concreteRv, err := concreteReflectValue(rv) + if err != nil { + return err + } + + if concreteRv.Kind() != reflect.Struct { + return utils.ErrNotStruct + } + + columns, err := rows.Columns() + if err != nil { + return err + } + + // Create a pointer of values for each field + values := make([]interface{}, len(columns)) + + for idx, column := range columns { + if field := s.FieldsByColumn[column]; field != nil { + v := rv + + for _, pos := range field.Position { + // TODO - If value is a pointer and nil, initialize it + // TODO - If value is a map and nil, initialize it + v = reflect.Indirect(v).Field(pos) + } + + values[idx] = v.Addr().Interface() + } else { + return fmt.Errorf("column %s not found in model", column) + } + } + + if !rows.Next() { + return sql.ErrNoRows + } + + err = rows.Scan(values...) + if err != nil { + return err + } + + return nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// RELATIONS +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// loadRelationsOne loads the relations for a single model, handling both one-to-one and +// one-to-many relationships +func (s *Schema) loadRelationsOne(concreteRv reflect.Value, db database.Querier) error { + if concreteRv.Kind() != reflect.Struct { + return utils.ErrNotStruct + } + + // Get the value of the ID field + id, zero := s.FieldsByColumn["id"].ValueOf(concreteRv) + if zero { + return nil + } + + for _, rel := range s.Relations { + relatedSchema, relatedModelPtr, err := parseRelatedSchema(rel) + if err != nil { + return err + } + + // Get the field in the struct to set the related model on + structField := getStructField(concreteRv, rel.Position) + + // Create the options to select related rows + options := &database.Options{Where: squirrel.Eq{rel.MatchOn: id}} + + if rel.HasMany { + structFieldType := structField.Type() + var structFieldPtr reflect.Value + + if rel.IsPtr { + // Create a new pointer slice + elemType := structFieldType.Elem() + structField.Set(reflect.New(elemType)) + structField.Elem().Set(reflect.MakeSlice(elemType, 0, 0)) + structFieldPtr = structField.Elem().Addr() + } else { + // Create a new slice + structField.Set(reflect.MakeSlice(structFieldType, 0, 0)) + structFieldPtr = structField.Addr() + } + + err = relatedSchema.Select(structFieldPtr.Interface(), options, db) + if err != nil && err != sql.ErrNoRows { + return err + } + } else { + err = relatedSchema.Select(relatedModelPtr.Interface(), options, db) + if err != nil { + if err == sql.ErrNoRows { + continue + } + + return err + } + + setRelatedField(structField, relatedModelPtr) + } + } + + return nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// loadRelationsMany loads the relations for a slice of models, handling both many-to-one and +// many-to-many relationships +func (s *Schema) loadRelationsMany(concreteRv reflect.Value, db database.Querier) error { + if concreteRv.Kind() != reflect.Slice { + return utils.ErrNotSlice + } + + // If the slice is empty, return + if concreteRv.Len() == 0 { + return nil + } + + // Get the IDs of all the models so we only do 1 query per relation + ids := []any{} + for i := 0; i < concreteRv.Len(); i++ { + v, zero := s.FieldsByColumn["id"].ValueOf(concreteRv.Index(i)) + if zero { + continue + } + ids = append(ids, v) + } + + for _, rel := range s.Relations { + relatedSchema, _, err := parseRelatedSchema(rel) + if err != nil { + return err + } + + // Create a slice to hold the related models + relatedSlicePtr := reflect.New(reflect.SliceOf(rel.RelatedType)) + relatedSlice := relatedSlicePtr.Interface() + + // Create the options to select related rows + options := &database.Options{Where: squirrel.Eq{rel.MatchOn: ids}} + err = relatedSchema.Select(relatedSlice, options, db) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + + return err + } + + relatedSliceValue := reflect.Indirect(reflect.ValueOf(relatedSlice)) + + if rel.HasMany { + // --- MANY TO MANY --- + + // Create a map of where the key is the MatchOn value and the value is a slice of + // related items + resultMap := make(map[any][]reflect.Value) + for i := 0; i < relatedSliceValue.Len(); i++ { + item := relatedSliceValue.Index(i) + id, _ := relatedSchema.FieldsByColumn[rel.MatchOn].ValueOf(reflect.Indirect(item)) + resultMap[id] = append(resultMap[id], item) + } + + for i := 0; i < concreteRv.Len(); i++ { + sliceItem := concreteRv.Index(i) + + v, zero := s.FieldsByColumn["id"].ValueOf(sliceItem) + if zero { + continue + } + + relatedItems, found := resultMap[v] + if !found { + continue + } + + // Get the field in the struct to set result + structField := getStructField(sliceItem, rel.Position) + structFieldType := structField.Type() + + // Create a new slice that will hold the related items, based upon whether it is a + // slice of a pointer slice + var newRelatedSlice reflect.Value + if structFieldType.Kind() == reflect.Ptr { + newRelatedSlice = reflect.MakeSlice(structFieldType.Elem(), 0, len(relatedItems)) + } else { + newRelatedSlice = reflect.MakeSlice(structFieldType, 0, len(relatedItems)) + } + + // Append the related items to the new slice + for _, item := range relatedItems { + newRelatedSlice = reflect.Append(newRelatedSlice, item) + } + + if structField.Kind() == reflect.Ptr { + structField.Set(reflect.New(structFieldType.Elem())) + structField.Elem().Set(newRelatedSlice) + } else { + structField.Set(newRelatedSlice) + } + } + } else { + // --- MANY TO ONE --- + + // Create a map of related items where the key is the MatchOn value + relatedMap := make(map[any]reflect.Value) + for i := 0; i < relatedSliceValue.Len(); i++ { + item := relatedSliceValue.Index(i) + id, _ := relatedSchema.FieldsByColumn[rel.MatchOn].ValueOf(reflect.Indirect(item)) + relatedMap[id] = item + } + + // Set the related items on the model + for i := 0; i < concreteRv.Len(); i++ { + sliceItem := concreteRv.Index(i) + + v, zero := s.FieldsByColumn["id"].ValueOf(sliceItem) + if zero { + continue + } + + relatedItem, found := relatedMap[v] + if !found { + continue + } + + // Get the field in the struct and set the value + relatedField := getStructField(sliceItem, rel.Position) + setRelatedField(relatedField, relatedItem) + } + } + } + + return nil +} diff --git a/utils/schema/schema_test.go b/utils/schema/schema_test.go new file mode 100644 index 0000000..fca162e --- /dev/null +++ b/utils/schema/schema_test.go @@ -0,0 +1,313 @@ +package schema + +import ( + "database/sql" + "testing" + + "github.com/Masterminds/squirrel" + "github.com/geerew/off-course/database" + "github.com/geerew/off-course/utils" + _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_Parse(t *testing.T) { + t.Run("struct", func(t *testing.T) { + sch, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, sch) + + require.Equal(t, "users", sch.Table) + require.Len(t, sch.Fields, 1) + require.Len(t, sch.Relations, 3) + }) + + t.Run("slice", func(t *testing.T) { + var users []*TestUser + schema, err := Parse(users) + require.NotNil(t, schema) + require.NoError(t, err) + }) + + t.Run("nil", func(t *testing.T) { + schema, err := Parse(nil) + require.Nil(t, schema) + require.ErrorIs(t, err, utils.ErrNilPtr) + }) + + t.Run("nil struct", func(t *testing.T) { + var user *TestUser + schema, err := Parse(user) + require.Nil(t, schema) + require.ErrorIs(t, err, utils.ErrInvalidValue) + }) + + t.Run("not a modeler", func(t *testing.T) { + schema, err := Parse(&struct{}{}) + require.Nil(t, schema) + require.ErrorIs(t, err, utils.ErrNotModeler) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_Select(t *testing.T) { + t.Run("struct success", func(t *testing.T) { + db := setup(t) + + sch, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, sch) + + u := &TestUser{} + err = sch.Select(u, &database.Options{Where: squirrel.Eq{"id": 1}}, db) + require.NoError(t, err) + + require.Equal(t, 1, u.ID) + require.Equal(t, 1, u.Profile.ID) + require.Len(t, u.Posts, 2) + require.Equal(t, "Post 1 by John", u.Posts[0].Title) + require.Equal(t, "Post 2 by John", u.Posts[1].Title) + require.Len(t, *u.PtrPosts, 2) + }) + + t.Run("slice success", func(t *testing.T) { + db := setup(t) + + sch, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, sch) + + u := []TestUser{} + err = sch.Select(&u, nil, db) + require.NoError(t, err) + require.Len(t, u, 2) + + require.Equal(t, 1, u[0].ID) + require.Equal(t, 1, u[0].Profile.ID) + require.Len(t, u[0].Posts, 2) + require.Equal(t, "Post 1 by John", u[0].Posts[0].Title) + require.Equal(t, "Post 2 by John", u[0].Posts[1].Title) + require.Len(t, *u[0].PtrPosts, 2) + + require.Equal(t, 2, u[1].ID) + require.Equal(t, 2, u[1].Profile.ID) + require.Len(t, u[1].Posts, 1) + require.Equal(t, "Post by Jane", u[1].Posts[0].Title) + require.Len(t, *u[1].PtrPosts, 1) + }) + + t.Run("no relation found", func(t *testing.T) { + db := setup(t) + + userSchema, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, userSchema) + + profileSchema, err := Parse(&TestProfile{}) + require.NoError(t, err) + require.NotNil(t, profileSchema) + + postSchema, err := Parse(&TestPost{}) + require.NoError(t, err) + require.NotNil(t, postSchema) + + // Delete all profiles + _, err = profileSchema.Delete(nil, db) + require.NoError(t, err) + + // Delete all posts + _, err = postSchema.Delete(nil, db) + require.NoError(t, err) + + u := &TestUser{} + err = userSchema.Select(u, &database.Options{Where: squirrel.Eq{"id": 1}}, db) + require.NoError(t, err) + + require.Equal(t, 1, u.ID) + require.Equal(t, TestProfile{}, u.Profile) + require.Len(t, u.Posts, 0) + }) + + t.Run("some relations found", func(t *testing.T) { + db := setup(t) + + userSchema, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, userSchema) + + profileSchema, err := Parse(&TestProfile{}) + require.NoError(t, err) + require.NotNil(t, profileSchema) + + postSchema, err := Parse(&TestPost{}) + require.NoError(t, err) + require.NotNil(t, postSchema) + + // Delete Johns profile + _, err = profileSchema.Delete(&database.Options{Where: squirrel.Eq{"id": 1}}, db) + require.NoError(t, err) + + // Delete Johns posts + _, err = postSchema.Delete(&database.Options{Where: squirrel.Eq{"user_id": 1}}, db) + require.NoError(t, err) + + u := []*TestUser{} + err = userSchema.Select(&u, nil, db) + require.NoError(t, err) + + require.Len(t, u, 2) + require.Equal(t, 1, u[0].ID) + require.Equal(t, TestProfile{}, u[0].Profile) + require.Len(t, u[0].Posts, 0) + + require.Equal(t, 2, u[1].ID) + require.Equal(t, 2, u[1].Profile.ID) + require.Len(t, u[1].Posts, 1) + }) + + t.Run("no rows", func(t *testing.T) { + db := setup(t) + + sch, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, sch) + + // Delete everything from the users table + _, err = sch.Delete(nil, db) + require.NoError(t, err) + + u := &TestUser{} + err = sch.Select(u, &database.Options{Where: squirrel.Eq{"id": 1}}, db) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("not a pointer", func(t *testing.T) { + db := setup(t) + + sch, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, sch) + + err = sch.Select(TestUser{}, nil, db) + require.ErrorIs(t, err, utils.ErrNotPtr) + }) + + t.Run("nil pointer", func(t *testing.T) { + db := setup(t) + + sch, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, sch) + + var u *TestUser + err = sch.Select(u, nil, db) + require.ErrorIs(t, err, utils.ErrNilPtr) + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_Count(t *testing.T) { + t.Run("success", func(t *testing.T) { + db := setup(t) + + sch, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, sch) + + count, err := sch.Count(&database.Options{Where: squirrel.Eq{"id": 1}}, db) + require.NoError(t, err) + require.Equal(t, 1, count) + }) + + t.Run("no rows", func(t *testing.T) { + db := setup(t) + + sch, err := Parse(&TestUser{}) + require.NoError(t, err) + require.NotNil(t, sch) + + // Delete everything from the users table + _, err = sch.Delete(nil, db) + require.NoError(t, err) + + count, err := sch.Count(nil, db) + require.NoError(t, err) + require.Zero(t, count) + + }) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Benchmark_Create(b *testing.B) { + db := setup(b) + + userSchema, err := Parse(&TestUser{}) + require.NoError(b, err) + require.NotNil(b, userSchema) + + _, err = userSchema.Delete(nil, db) + require.NoError(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + sch, err := Parse(&TestUser{}) + require.NoError(b, err) + + u := &TestUser{ + TestBase: TestBase{ + ID: i, + }, + } + + builder := sch.InsertBuilder(u) + query, args, _ := builder.ToSql() + + _, err = db.Exec(query, args...) + require.NoError(b, err) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Benchmark_Scan(b *testing.B) { + db := setup(b) + + userSchema, err := Parse(&TestUser{}) + require.NoError(b, err) + require.NotNil(b, userSchema) + + // Empty users + _, err = userSchema.Delete(nil, db) + require.NoError(b, err) + + profileSchema, err := Parse(&TestProfile{}) + require.NoError(b, err) + require.NotNil(b, profileSchema) + + // Empty profiles + _, err = profileSchema.Delete(nil, db) + require.NoError(b, err) + + // Insert 1000 users and profiles + for i := 0; i < 1000; i++ { + _, err = db.Exec(`INSERT INTO users (id) VALUES (?)`, i) + require.NoError(b, err) + + _, err = db.Exec(`INSERT INTO profiles (user_id, name, username, email) VALUES (?, ?, ?, ?)`, i, "John Doe", "johndoe", "john@test.com") + require.NoError(b, err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + u := &TestUser{} + err = userSchema.Select(u, &database.Options{Where: squirrel.Eq{"id": (i % 1000)}}, db) + require.NoError(b, err) + } +} diff --git a/utils/schema/utils.go b/utils/schema/utils.go new file mode 100644 index 0000000..a37613b --- /dev/null +++ b/utils/schema/utils.go @@ -0,0 +1,78 @@ +package schema + +import ( + "reflect" + + "github.com/geerew/off-course/utils" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// parseRelatedSchema parses the related schema and returns a reflect value that can +// be used to set the related field +func parseRelatedSchema(rel *relation) (*Schema, reflect.Value, error) { + concreteType := rel.RelatedType + for concreteType.Kind() == reflect.Ptr { + concreteType = concreteType.Elem() + } + + relatedModelPtr := reflect.New(concreteType) + relatedModel := relatedModelPtr.Interface() + + relatedSchema, err := Parse(relatedModel) + if err != nil { + return nil, reflect.Value{}, err + } + + return relatedSchema, relatedModelPtr, nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// getStructField returns a struct field at the given position +func getStructField(concreteRv reflect.Value, position []int) reflect.Value { + structField := concreteRv + for _, pos := range position { + structField = reflect.Indirect(structField).Field(pos) + } + + return structField +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// setRelatedField sets the related field +func setRelatedField(relatedField reflect.Value, value reflect.Value) { + if relatedField.Kind() == reflect.Ptr { + if value.Kind() == reflect.Ptr { + relatedField.Set(value) + } else { + relatedField.Set(value.Addr()) + } + } else { + if value.Kind() == reflect.Ptr { + relatedField.Set(value.Elem()) + } else { + relatedField.Set(value) + } + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// concreteReflectValue returns the concrete reflect value +func concreteReflectValue(v reflect.Value) (reflect.Value, error) { + for v.Kind() == reflect.Ptr { + if v.IsNil() && v.CanAddr() { + v.Set(reflect.New(v.Type().Elem())) + } + + v = v.Elem() + } + + if !v.IsValid() { + return v, utils.ErrInvalidValue + } + + return v, nil +} diff --git a/utils/schema/utils_test.go b/utils/schema/utils_test.go new file mode 100644 index 0000000..9ac5472 --- /dev/null +++ b/utils/schema/utils_test.go @@ -0,0 +1,61 @@ +package schema + +import ( + "reflect" + "testing" + + "github.com/geerew/off-course/utils" + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------- + +func Test_ConcreteReflectValue(t *testing.T) { + t.Run("non-pointer", func(t *testing.T) { + i := 42 + v := reflect.ValueOf(i) + + result, err := concreteReflectValue(v) + require.NoError(t, err) + require.Equal(t, v, result) + require.Equal(t, reflect.Int, result.Kind()) + }) + + t.Run("pointer", func(t *testing.T) { + i := 42 + v := reflect.ValueOf(&i) + + result, err := concreteReflectValue(v) + require.NoError(t, err) + require.Equal(t, reflect.Int, result.Kind()) + }) + + t.Run("nested pointer", func(t *testing.T) { + i := 42 + ptr := &i + ptrPtr := &ptr + v := reflect.ValueOf(&ptrPtr) + + result, err := concreteReflectValue(v) + require.NoError(t, err) + require.Equal(t, reflect.Int, result.Kind()) + }) + + t.Run("nil pointer", func(t *testing.T) { + var i *int + v := reflect.ValueOf(i) + + result, err := concreteReflectValue(v) + require.ErrorIs(t, err, utils.ErrInvalidValue) + require.Equal(t, reflect.Invalid, result.Kind()) + }) + + t.Run("nested pointer nil", func(t *testing.T) { + var i **int + v := reflect.ValueOf(i) + + result, err := concreteReflectValue(v) + require.ErrorIs(t, err, utils.ErrInvalidValue) + require.Equal(t, reflect.Invalid, result.Kind()) + }) +} diff --git a/utils/types/asset.go b/utils/types/asset.go index bff93fb..9b5cb6d 100644 --- a/utils/types/asset.go +++ b/utils/types/asset.go @@ -53,7 +53,6 @@ func NewAsset(ext string) *Asset { return &Asset{s: AssetHTML} case "pdf": return &Asset{s: AssetPDF} - } return nil @@ -131,11 +130,7 @@ func (a *Asset) UnmarshalJSON(b []byte) error { // Value implements the `driver.Valuer` interface func (a Asset) Value() (driver.Value, error) { - if a.s == "" { - return nil, nil - } - - return a.s, nil + return a.String(), nil } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/utils/types/asset_test.go b/utils/types/asset_test.go index 5f904ac..60e4079 100644 --- a/utils/types/asset_test.go +++ b/utils/types/asset_test.go @@ -103,7 +103,7 @@ func TestAsset_MarshalJSON(t *testing.T) { require.NotNil(t, a) res, err := a.MarshalJSON() - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, tt.expected, string(res)) } } @@ -145,11 +145,11 @@ func TestAsset_UnmarshalJSON(t *testing.T) { func TestAsset_Value(t *testing.T) { tests := []struct { input string - expected AssetType + expected string }{ - {"mp4", AssetVideo}, - {"html", AssetHTML}, - {"pdf", AssetPDF}, + {"mp4", "video"}, + {"html", "html"}, + {"pdf", "pdf"}, } for _, tt := range tests { @@ -157,15 +157,15 @@ func TestAsset_Value(t *testing.T) { require.NotNil(t, a) res, err := a.Value() - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, tt.expected, res) } // Nil a := Asset{} res, err := a.Value() - require.Nil(t, err) - require.Nil(t, res) + require.NoError(t, err) + require.Empty(t, res) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -185,7 +185,7 @@ func TestAsset_Scan(t *testing.T) { a := Asset{} err := a.Scan(tt.value) - require.Nil(t, err) + require.NoError(t, err) require.Contains(t, a.s, tt.expected) } }) diff --git a/utils/types/date_time.go b/utils/types/date_time.go new file mode 100644 index 0000000..9ac5dc0 --- /dev/null +++ b/utils/types/date_time.go @@ -0,0 +1,134 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "time" + + "github.com/spf13/cast" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// DefaultDateLayout specifies the default app date strings layout +const DefaultDateLayout = "2006-01-02 15:04:05.000Z" + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// NowDateTime returns new DateTime instance with the current local time +func NowDateTime() DateTime { + return DateTime{t: time.Now()} +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// ParseDateTime creates a new DateTime from the provided value +func ParseDateTime(value any) (DateTime, error) { + d := DateTime{} + err := d.Scan(value) + return d, err +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// DateTime represents a [time.Time] instance in UTC that is wrapped and serialized +// using the app default date layout +type DateTime struct { + t time.Time +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Time returns the internal [time.Time] instance +func (d DateTime) Time() time.Time { + return d.t +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// IsZero checks whether the current DateTime instance has zero time value +func (d DateTime) IsZero() bool { + return d.Time().IsZero() +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Equal checks if two DateTime instances represent the same point in time +func (d DateTime) Equal(other DateTime) bool { + return d.t.UTC().Truncate(time.Millisecond).Equal(other.t.UTC().Truncate(time.Millisecond)) + +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// String serializes the current DateTime instance into a formatted +// UTC date string +// +// A zero value is serialized to an empty string +func (d DateTime) String() string { + t := d.Time() + if t.IsZero() { + return "" + } + return t.UTC().Format(DefaultDateLayout) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// MarshalJSON implements the [json.Marshaler] interface +func (d DateTime) MarshalJSON() ([]byte, error) { + return []byte(`"` + d.String() + `"`), nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UnmarshalJSON implements the [json.Unmarshaler] interface +func (d *DateTime) UnmarshalJSON(b []byte) error { + var raw string + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + return d.Scan(raw) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Value implements the [driver.Valuer] interface +func (d DateTime) Value() (driver.Value, error) { + return d.String(), nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Scan implements [sql.Scanner] interface to scan the provided value +// into the current DateTime instance +func (d *DateTime) Scan(value any) error { + switch v := value.(type) { + case time.Time: + d.t = v + case DateTime: + d.t = v.Time() + case string: + if v == "" { + d.t = time.Time{} + } else { + t, err := time.Parse(DefaultDateLayout, v) + if err != nil { + // check for other common date layouts + t = cast.ToTime(v) + } + d.t = t + } + case int, int64, int32, uint, uint64, uint32: + d.t = cast.ToTime(v) + default: + str := cast.ToString(v) + if str == "" { + d.t = time.Time{} + } else { + d.t = cast.ToTime(str) + } + } + + return nil +} diff --git a/utils/types/date_time_test.go b/utils/types/date_time_test.go new file mode 100644 index 0000000..c78e478 --- /dev/null +++ b/utils/types/date_time_test.go @@ -0,0 +1,207 @@ +package types + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_NowDateTime(t *testing.T) { + now := time.Now().UTC().Format("2006-01-02 15:04:05") // without ms part for test consistency + dt := NowDateTime() + + require.Contains(t, dt.String(), now) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_ParseDateTime(t *testing.T) { + nowTime := time.Now().UTC() + nowDateTime, _ := ParseDateTime(nowTime) + nowStr := nowTime.Format(DefaultDateLayout) + + scenarios := []struct { + value any + expected string + }{ + {nil, ""}, + {"", ""}, + {"invalid", ""}, + {nowDateTime, nowStr}, + {nowTime, nowStr}, + {1641024040, "2022-01-01 08:00:40.000Z"}, + {int32(1641024040), "2022-01-01 08:00:40.000Z"}, + {int64(1641024040), "2022-01-01 08:00:40.000Z"}, + {uint(1641024040), "2022-01-01 08:00:40.000Z"}, + {uint64(1641024040), "2022-01-01 08:00:40.000Z"}, + {uint32(1641024040), "2022-01-01 08:00:40.000Z"}, + {"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"}, + } + + for i, s := range scenarios { + dt, err := ParseDateTime(s.value) + + require.Nil(t, err, "(%d) Failed to parse %v: %v", i, s.value, err) + require.Equal(t, s.expected, dt.String(), "(%d) Expected %q, got %q", i, s.expected, dt.String()) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func Test_DateTimeTime(t *testing.T) { + str := "2022-01-01 11:23:45.678Z" + + expected, err := time.Parse(DefaultDateLayout, str) + require.NoError(t, err) + + dt, err := ParseDateTime(str) + require.NoError(t, err) + + result := dt.Time() + require.Equal(t, expected, result) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestDateTime_IsZero(t *testing.T) { + dt0 := DateTime{} + require.True(t, dt0.IsZero()) + + dt1 := NowDateTime() + require.False(t, dt1.IsZero()) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestDateTime_Equal(t *testing.T) { + scenarios := []struct { + dt1 DateTime + dt2 DateTime + expected bool + }{ + {DateTime{}, DateTime{}, true}, // Both zero values + {NowDateTime(), NowDateTime(), false}, // Different current times + {NowDateTime(), NowDateTime(), false}, // Another set of different times + { + dt1: DateTime{t: time.Date(2022, 1, 1, 11, 23, 45, 678000000, time.UTC)}, + dt2: DateTime{t: time.Date(2022, 1, 1, 11, 23, 45, 678000000, time.UTC)}, + expected: true, // Matching times + }, + { + dt1: DateTime{t: time.Date(2022, 1, 1, 11, 23, 45, 0, time.UTC)}, + dt2: DateTime{t: time.Date(2022, 1, 1, 11, 23, 46, 0, time.UTC)}, + expected: false, // Different times + }, + } + + for i, s := range scenarios { + require.True(t, s.dt1.Equal(s.dt1), "(%d) Expected %v.Equal(%v) to be true", i, s.dt1, s.dt1) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestDateTime_String(t *testing.T) { + dt0 := DateTime{} + require.Empty(t, dt0.String()) + + expected := "2022-01-01 11:23:45.678Z" + dt1, _ := ParseDateTime(expected) + require.Equal(t, expected, dt1.String()) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestDateTime_MarshalJSON(t *testing.T) { + scenarios := []struct { + date string + expected string + }{ + {"", `""`}, + {"2022-01-01 11:23:45.678", `"2022-01-01 11:23:45.678Z"`}, + } + + for i, s := range scenarios { + dt, err := ParseDateTime(s.date) + require.Nil(t, err, "(%d) %v", i, err) + + result, err := dt.MarshalJSON() + require.Nil(t, err, "(%d) %v", i, err) + require.Equal(t, s.expected, string(result), "(%d) Expected %q, got %q", i, s.expected, string(result)) + + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestDateTime_UnmarshalJSON(t *testing.T) { + scenarios := []struct { + date string + expected string + }{ + {"", ""}, + {"invalid_json", ""}, + {"'123'", ""}, + {"2022-01-01 11:23:45.678", ""}, + {`"2022-01-01 11:23:45.678"`, "2022-01-01 11:23:45.678Z"}, + } + + for i, s := range scenarios { + dt := DateTime{} + dt.UnmarshalJSON([]byte(s.date)) + require.Equal(t, s.expected, dt.String(), "(%d) Expected %q, got %q", i, s.expected, dt.String()) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestDateTime_Value(t *testing.T) { + scenarios := []struct { + value any + expected string + }{ + {"", ""}, + {"invalid", ""}, + {1641024040, "2022-01-01 08:00:40.000Z"}, + {"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"}, + {NowDateTime(), NowDateTime().String()}, + } + + for i, s := range scenarios { + dt, _ := ParseDateTime(s.value) + result, err := dt.Value() + require.Nil(t, err, "(%d) %v", i, err) + require.Equal(t, s.expected, result, "(%d) Expected %q, got %q", i, s.expected, result) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestDateTime_Scan(t *testing.T) { + now := time.Now().UTC().Format("2006-01-02 15:04:05") // without ms part for test consistency + + scenarios := []struct { + value any + expected string + }{ + {nil, ""}, + {"", ""}, + {"invalid", ""}, + {NowDateTime(), now}, + {time.Now(), now}, + {1.0, ""}, + {1641024040, "2022-01-01 08:00:40.000Z"}, + {"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"}, + } + + for i, s := range scenarios { + dt := DateTime{} + + err := dt.Scan(s.value) + require.Nil(t, err, "(%d) %v", i, err) + require.Contains(t, dt.String(), s.expected, "(%d) Expected %q, got %q", i, s.expected, dt.String()) + } +} diff --git a/utils/types/log.go b/utils/types/log.go index bfe69cf..817bec0 100644 --- a/utils/types/log.go +++ b/utils/types/log.go @@ -11,7 +11,7 @@ type LogType int const ( LogTypeRequest LogType = iota LogTypeCron - LogTypeCourseScanner + LogTypeScanner LogTypeFileSystem LogTypeDB ) @@ -23,7 +23,7 @@ func AllLogTypes() []string { return []string{ LogTypeRequest.String(), LogTypeCron.String(), - LogTypeCourseScanner.String(), + LogTypeScanner.String(), LogTypeFileSystem.String(), LogTypeDB.String(), } @@ -33,7 +33,7 @@ func AllLogTypes() []string { // String returns the string representation of the LogType func (lt LogType) String() string { - names := [...]string{"request", "cron", "course scanner", "file system", "db"} + names := [...]string{"request", "cron", "scanner", "file system", "db"} if int(lt) < 0 || int(lt) >= len(names) { return "unknown" diff --git a/utils/types/log_test.go b/utils/types/log_test.go index 2228037..fa563ba 100644 --- a/utils/types/log_test.go +++ b/utils/types/log_test.go @@ -16,7 +16,7 @@ func TestLog_String(t *testing.T) { }{ {LogTypeRequest, "request"}, {LogTypeCron, "cron"}, - {LogTypeCourseScanner, "course scanner"}, + {LogTypeScanner, "scanner"}, {LogTypeFileSystem, "file system"}, {LogTypeDB, "db"}, {LogType(999), "unknown"}, @@ -32,7 +32,7 @@ func TestLog_String(t *testing.T) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ func TestLog_AllLogTypes(t *testing.T) { - expected := []string{"request", "cron", "course scanner", "file system", "db"} + expected := []string{"request", "cron", "scanner", "file system", "db"} require.Equal(t, expected, AllLogTypes()) } @@ -45,7 +45,7 @@ func TestLog_LogValue(t *testing.T) { }{ {LogTypeRequest, "request"}, {LogTypeCron, "cron"}, - {LogTypeCourseScanner, "course scanner"}, + {LogTypeScanner, "scanner"}, {LogTypeFileSystem, "file system"}, {LogTypeDB, "db"}, } diff --git a/utils/types/scan_status.go b/utils/types/scan_status.go index e3b1f28..2b8686c 100644 --- a/utils/types/scan_status.go +++ b/utils/types/scan_status.go @@ -27,21 +27,22 @@ const ( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// NewScanStatus creates a ScanStatus type with the status of +// NewScanStatusWaiting creates a ScanStatus type with the status of // waiting -func NewScanStatus(s ScanStatusType) ScanStatus { - switch s { - case ScanStatusWaiting: - return ScanStatus{s: ScanStatusWaiting} - case ScanStatusProcessing: - return ScanStatus{s: ScanStatusProcessing} - } - +func NewScanStatusWaiting() ScanStatus { return ScanStatus{s: ScanStatusWaiting} } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// NewScanStatusProcessing creates a ScanStatus type with the status of +// processing +func NewScanStatusProcessing() ScanStatus { + return ScanStatus{s: ScanStatusProcessing} +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // SetWaiting updates the state to waiting func (ss *ScanStatus) SetWaiting() { ss.s = ScanStatusWaiting @@ -63,6 +64,13 @@ func (ss ScanStatus) IsWaiting() bool { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// IsProcessing returns true is the status is process +func (ss ScanStatus) IsProcessing() bool { + return ss.s == ScanStatusProcessing +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // String implements the `Stringer` interface func (ss ScanStatus) String() string { return fmt.Sprint(ss.s) @@ -98,7 +106,6 @@ func (ss ScanStatus) Value() (driver.Value, error) { // Scan implements `sql.Scanner` interface func (ss *ScanStatus) Scan(value any) error { - vv := cast.ToString(value) switch vv { @@ -107,8 +114,7 @@ func (ss *ScanStatus) Scan(value any) error { case string(ScanStatusProcessing): ss.s = ScanStatusProcessing default: - // Default to waiting - ss.s = ScanStatusWaiting + ss.s = "" } return nil diff --git a/utils/types/scan_status_test.go b/utils/types/scan_status_test.go index 0cf980e..3346f1f 100644 --- a/utils/types/scan_status_test.go +++ b/utils/types/scan_status_test.go @@ -8,29 +8,20 @@ import ( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func TestScanStatus_NewScanStatus(t *testing.T) { +func TestScanStatus_NewScanStatusWaiting(t *testing.T) { + require.Equal(t, ScanStatusWaiting, NewScanStatusWaiting().s) +} - tests := []struct { - input ScanStatusType - expected ScanStatusType - }{ - {ScanStatusWaiting, ScanStatusWaiting}, - {ScanStatusProcessing, ScanStatusProcessing}, - {"sdf", ScanStatusWaiting}, - } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - for _, tt := range tests { - s := NewScanStatus(tt.input) - require.Equal(t, tt.expected, s.s) - } +func TestScanStatus_NewScanStatusProcessing(t *testing.T) { + require.Equal(t, ScanStatusProcessing, NewScanStatusProcessing().s) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ func TestScanStatus_SetWaiting(t *testing.T) { - s := NewScanStatus(ScanStatusProcessing) - require.Equal(t, ScanStatusProcessing, s.s) - + s := NewScanStatusProcessing() s.SetWaiting() require.Equal(t, ScanStatusWaiting, s.s) } @@ -38,9 +29,7 @@ func TestScanStatus_SetWaiting(t *testing.T) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ func TestScanStatus_SetProcessing(t *testing.T) { - s := NewScanStatus(ScanStatusWaiting) - require.Equal(t, ScanStatusWaiting, s.s) - + s := NewScanStatusWaiting() s.SetProcessing() require.Equal(t, ScanStatusProcessing, s.s) } @@ -48,38 +37,36 @@ func TestScanStatus_SetProcessing(t *testing.T) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ func TestScanStatus_IsWaiting(t *testing.T) { - tests := []struct { - input ScanStatusType - expected bool - }{ - {ScanStatusWaiting, true}, - {ScanStatusProcessing, false}, - } - - for _, tt := range tests { - s := NewScanStatus(tt.input) - require.Equal(t, tt.expected, s.IsWaiting()) - } + require.True(t, NewScanStatusWaiting().IsWaiting()) + require.False(t, NewScanStatusProcessing().IsWaiting()) + require.False(t, ScanStatus{}.IsWaiting()) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -func TestScanStatus_MarshalJSON(t *testing.T) { - tests := []struct { - input ScanStatusType - expected string - }{ - {ScanStatusWaiting, `"waiting"`}, - {ScanStatusProcessing, `"processing"`}, - } +func TestScanStatus_IsProcess(t *testing.T) { + require.False(t, NewScanStatusWaiting().IsProcessing()) + require.True(t, NewScanStatusProcessing().IsProcessing()) + require.False(t, ScanStatus{}.IsProcessing()) +} - for _, tt := range tests { - s := NewScanStatus(tt.input) +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - res, err := s.MarshalJSON() - require.Nil(t, err) - require.Equal(t, tt.expected, string(res)) - } +func TestScanStatus_MarshalJSON(t *testing.T) { + waiting := NewScanStatusWaiting() + res, err := waiting.MarshalJSON() + require.NoError(t, err) + require.Equal(t, `"waiting"`, string(res)) + + processing := NewScanStatusProcessing() + res, err = processing.MarshalJSON() + require.NoError(t, err) + require.Equal(t, `"processing"`, string(res)) + + empty := ScanStatus{} + res, err = empty.MarshalJSON() + require.NoError(t, err) + require.Equal(t, `""`, string(res)) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -94,8 +81,8 @@ func TestScanStatus_UnmarshalJSON(t *testing.T) { {"", "", "unexpected end of JSON input"}, {"xxx", "", "invalid character 'x' looking for beginning of value"}, // Defaults - {`""`, ScanStatusWaiting, ""}, - {`"bob"`, ScanStatusWaiting, ""}, + {`""`, "", ""}, + {`"bob"`, "", ""}, // Success {`"waiting"`, ScanStatusWaiting, ""}, {`"processing"`, ScanStatusProcessing, ""}, @@ -116,21 +103,20 @@ func TestScanStatus_UnmarshalJSON(t *testing.T) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ func TestScanStatus_Value(t *testing.T) { - tests := []struct { - input ScanStatusType - expected string - }{ - {ScanStatusWaiting, "waiting"}, - {ScanStatusProcessing, "processing"}, - } - - for _, tt := range tests { - s := NewScanStatus(tt.input) - - res, err := s.Value() - require.Nil(t, err) - require.Equal(t, tt.expected, res) - } + waiting := NewScanStatusWaiting() + res, err := waiting.Value() + require.NoError(t, err) + require.Equal(t, "waiting", res) + + processing := NewScanStatusProcessing() + res, err = processing.Value() + require.NoError(t, err) + require.Equal(t, "processing", res) + + empty := ScanStatus{} + res, err = empty.Value() + require.NoError(t, err) + require.Equal(t, "", res) } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -140,9 +126,11 @@ func TestScanStatus_Scan(t *testing.T) { value any expected string }{ - {nil, "waiting"}, - {"", "waiting"}, - {"invalid", "waiting"}, + // defaults + {nil, ""}, + {"", ""}, + {"invalid", ""}, + // Values {"waiting", "waiting"}, {"processing", "processing"}, } @@ -151,7 +139,7 @@ func TestScanStatus_Scan(t *testing.T) { ss := ScanStatus{} err := ss.Scan(tt.value) - require.Nil(t, err) + require.NoError(t, err) require.Contains(t, ss.s, tt.expected) } } diff --git a/utils/types/user_role.go b/utils/types/user_role.go new file mode 100644 index 0000000..d174909 --- /dev/null +++ b/utils/types/user_role.go @@ -0,0 +1,94 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "fmt" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UserRole defines the possible roles for a user +type UserRole string + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +const ( + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// IsValid checks if the role is valid +func (r UserRole) IsValid() bool { + switch r { + case UserRoleAdmin, UserRoleUser: + return true + } + return false +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// String implements the Stringer interface +func (r UserRole) String() string { + return string(r) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// MarshalJSON implements the json.Marshaler interface +func (r UserRole) MarshalJSON() ([]byte, error) { + if !r.IsValid() { + return nil, fmt.Errorf("invalid user role: %s", r) + } + return json.Marshal(string(r)) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// UnmarshalJSON implements the json.Unmarshaler interface +func (r *UserRole) UnmarshalJSON(data []byte) error { + var role string + if err := json.Unmarshal(data, &role); err != nil { + return err + } + + userRole := UserRole(role) + if !userRole.IsValid() { + return fmt.Errorf("invalid user role: %s", role) + } + + *r = userRole + return nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Value implements the driver.Valuer interface for database serialization +func (r UserRole) Value() (driver.Value, error) { + if !r.IsValid() { + return nil, fmt.Errorf("invalid user role: %s", r) + } + return string(r), nil +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Scan implements the sql.Scanner interface +func (r *UserRole) Scan(value interface{}) error { + role, ok := value.(string) + if !ok { + return errors.New("invalid data type for UserRole") + } + + userRole := UserRole(role) + if !userRole.IsValid() { + return fmt.Errorf("invalid user role: %s", role) + } + + *r = userRole + return nil +} diff --git a/utils/types/user_role_test.go b/utils/types/user_role_test.go new file mode 100644 index 0000000..5c2f6fa --- /dev/null +++ b/utils/types/user_role_test.go @@ -0,0 +1,142 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestUserRole_String(t *testing.T) { + assert.Equal(t, "admin", UserRoleAdmin.String()) + assert.Equal(t, "user", UserRoleUser.String()) + assert.Equal(t, "invalid", UserRole("invalid").String()) +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestUserRole_IsValid(t *testing.T) { + tests := []struct { + role UserRole + expected bool + }{ + {UserRoleAdmin, true}, + {UserRoleUser, true}, + {UserRole("invalid"), false}, + } + + for _, tt := range tests { + t.Run(string(tt.role), func(t *testing.T) { + assert.Equal(t, tt.expected, tt.role.IsValid()) + }) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestUserRole_MarshalJSON(t *testing.T) { + tests := []struct { + role UserRole + expected string + hasError bool + }{ + {UserRoleAdmin, `"admin"`, false}, + {UserRoleUser, `"user"`, false}, + {UserRole("invalid"), "", true}, + } + + for _, tt := range tests { + t.Run(string(tt.role), func(t *testing.T) { + data, err := json.Marshal(tt.role) + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, string(data)) + } + }) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestUserRole_UnmarshalJSON(t *testing.T) { + tests := []struct { + data string + expected UserRole + hasError bool + }{ + {`"admin"`, UserRoleAdmin, false}, + {`"user"`, UserRoleUser, false}, + {`"invalid"`, "", true}, + } + + for _, tt := range tests { + t.Run(tt.data, func(t *testing.T) { + var role UserRole + err := json.Unmarshal([]byte(tt.data), &role) + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, role) + } + }) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestUserRole_Scan(t *testing.T) { + tests := []struct { + input interface{} + expected UserRole + hasError bool + }{ + {"admin", UserRoleAdmin, false}, + {"user", UserRoleUser, false}, + {"invalid", "", true}, + } + + for _, tt := range tests { + t.Run(tt.input.(string), func(t *testing.T) { + var role UserRole + err := role.Scan(tt.input) + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, role) + } + }) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +func TestUserRole_Value(t *testing.T) { + tests := []struct { + role UserRole + expected driver.Value + hasError bool + }{ + {UserRoleAdmin, "admin", false}, + {UserRoleUser, "user", false}, + {UserRole("invalid"), nil, true}, + } + + for _, tt := range tests { + t.Run(string(tt.role), func(t *testing.T) { + value, err := tt.role.Value() + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, value) + } + }) + } +} diff --git a/utils/utils.go b/utils/utils.go index 3df79f9..a43fa96 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -43,6 +43,8 @@ func TrimQuotes(s string) string { // are correctly interpreted. If the path starts with a drive letter, it appends a // backslash (\) to paths like "C:" to make them "C:\", and inserts a backslash in paths // like "C:folder" to make them "C:\folder" +// +// Skipped on non-Windows platforms func NormalizeWindowsDrive(path string) string { if runtime.GOOS == "windows" { if len(path) >= 2 && path[1] == ':' { @@ -113,6 +115,18 @@ func EscapeBackslashes(path string) string { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// CheckTruth check string true or not +func CheckTruth(vals ...string) bool { + for _, val := range vals { + if val != "" && !strings.EqualFold(val, "false") { + return true + } + } + return false +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // DiffSliceOfStructsByKey takes in two slices of type T (left and right) and a key (string) as // arguments. The key defines the which key to use when comparing. // @@ -303,3 +317,24 @@ func Map[T, V any](ts []T, fn func(T) V) []V { } return result } + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// snakeCase converts a string to snake_case +func SnakeCase(s string) string { + var b strings.Builder + b.Grow(len(s) + 5) + + for i, r := range s { + // Check if the current rune is uppercase + if i > 0 && 'A' <= r && r <= 'Z' { + // Add underscore if previous rune is lowercase (or non-uppercase letter) + if 'a' <= rune(s[i-1]) && rune(s[i-1]) <= 'z' { + b.WriteByte('_') + } + } + b.WriteRune(r) + } + + return strings.ToLower(b.String()) +} diff --git a/utils/utils_test.go b/utils/utils_test.go index 734a829..92c1b04 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -5,6 +5,7 @@ import ( "runtime" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -32,7 +33,7 @@ func Test_TrimQuotes(t *testing.T) { func Test_DecodeString(t *testing.T) { t.Run("empty", func(t *testing.T) { res, err := DecodeString("") - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, "", res) }) @@ -50,7 +51,7 @@ func Test_DecodeString(t *testing.T) { t.Run("success", func(t *testing.T) { res, err := DecodeString("JTJGdGVzdCUyRmRhdGE=") - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, "/test/data", res) }) } @@ -71,6 +72,29 @@ func Test_EncodeString(t *testing.T) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +func Test_CheckTruth(t *testing.T) { + checkTruthTests := []struct { + v string + out bool + }{ + {"123", true}, + {"true", true}, + {"", false}, + {"false", false}, + {"False", false}, + {"FALSE", false}, + {"\u0046alse", false}, + } + + for _, test := range checkTruthTests { + t.Run(test.v, func(t *testing.T) { + assert.Equal(t, test.out, CheckTruth(test.v)) + }) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + func Test_DiffSliceOfStructsByKey(t *testing.T) { // Struct for testing type testStruct struct { @@ -101,7 +125,7 @@ func Test_DiffSliceOfStructsByKey(t *testing.T) { t.Run("both empty", func(t *testing.T) { leftDiff, rightDiff, err := DiffSliceOfStructsByKey[testStruct](nil, nil, "") - require.Nil(t, err) + require.NoError(t, err) require.Nil(t, leftDiff) require.Nil(t, rightDiff) }) @@ -114,7 +138,7 @@ func Test_DiffSliceOfStructsByKey(t *testing.T) { } leftDiff, rightDiff, err := DiffSliceOfStructsByKey(nil, right, "ID") - require.Nil(t, err) + require.NoError(t, err) require.Empty(t, leftDiff) require.Len(t, rightDiff, 5) }) @@ -127,7 +151,7 @@ func Test_DiffSliceOfStructsByKey(t *testing.T) { } leftDiff, rightDiff, err := DiffSliceOfStructsByKey(left, nil, "ID") - require.Nil(t, err) + require.NoError(t, err) require.Len(t, leftDiff, 5) require.Empty(t, rightDiff) }) @@ -142,7 +166,7 @@ func Test_DiffSliceOfStructsByKey(t *testing.T) { } leftDiff, rightDiff, err := DiffSliceOfStructsByKey(left, right, "ID") - require.Nil(t, err) + require.NoError(t, err) require.Empty(t, leftDiff) require.Empty(t, rightDiff) }) @@ -157,7 +181,7 @@ func Test_DiffSliceOfStructsByKey(t *testing.T) { } leftDiff, rightDiff, err := DiffSliceOfStructsByKey(left, right, "ID") - require.Nil(t, err) + require.NoError(t, err) require.Len(t, leftDiff, 5) require.Len(t, rightDiff, 5) }) @@ -178,7 +202,7 @@ func Test_DiffSliceOfStructsByKey(t *testing.T) { right = append(right, left[0]) leftDiff, rightDiff, err := DiffSliceOfStructsByKey(left, right, "ID") - require.Nil(t, err) + require.NoError(t, err) require.Len(t, leftDiff, 4) require.Len(t, rightDiff, 3) }) @@ -196,14 +220,14 @@ func Test_DiffSliceOfStructsByKey(t *testing.T) { left = append(left, &testStruct{ID: 5, Title: "Test"}) leftDiff, rightDiff, err := DiffSliceOfStructsByKey(left, right, "ID") - require.Nil(t, err) + require.NoError(t, err) require.Len(t, leftDiff, 1) require.Zero(t, len(rightDiff)) // Give right 1 extra (plus the new left one) right = append(right, left[len(left)-1], &testStruct{ID: 6, Title: "Test"}) leftDiff, rightDiff, err = DiffSliceOfStructsByKey(left, right, "ID") - require.Nil(t, err) + require.NoError(t, err) require.Zero(t, len(leftDiff)) require.Len(t, rightDiff, 1) }) @@ -379,3 +403,31 @@ func Test_EscapeBackslashes(t *testing.T) { require.Equal(t, test.expected, result) } } + +// ------------------------------------------------------- + +func Test_SnakeCase(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"camelCase", "camel_case"}, + {"PascalCase", "pascal_case"}, + {"snake_case", "snake_case"}, + {"Already_Snake_Case", "already_snake_case"}, + {"UpperCase", "upper_case"}, + {"lowercase", "lowercase"}, + {"SimpleTest", "simple_test"}, + {"TestID", "test_id"}, + {"AnotherTestCase", "another_test_case"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + actual := SnakeCase(tt.input) + if actual != tt.expected { + t.Errorf("snakeCase(%q) = %q; want %q", tt.input, actual, tt.expected) + } + }) + } +}