diff --git a/billing/subscription/service.go b/billing/subscription/service.go index e89ff0c11..90e2d6542 100644 --- a/billing/subscription/service.go +++ b/billing/subscription/service.go @@ -202,25 +202,43 @@ func (s *Service) SyncWithProvider(ctx context.Context, customr customer.Custome } // update plan id if it's changed - planByStripeSubscription, err := s.findPlanByStripeSubscription(ctx, stripeSubscription) - if err != nil { - subErrs = append(subErrs, fmt.Errorf("failed to find Plan By Stripe Subscription: %w", err)) + currentPlanID, nextPlanID, err := s.getPlanFromSchedule(ctx, stripeSchedule) + if errors.Is(err, ErrNoPhaseActive) { + currentPlan, err := s.findPlanByStripeSubscription(ctx, stripeSubscription) + if err != nil { + subErrs = append(subErrs, fmt.Errorf("failed to find plan from stripe subscription: %w", err)) + continue + } + currentPlanID = currentPlan.ID + } else if err != nil { + subErrs = append(subErrs, fmt.Errorf("failed to find plan from stripe schedule: %w", err)) continue } - if sub.PlanID != planByStripeSubscription.ID { - sub.PlanID = planByStripeSubscription.ID + + if sub.PlanID != currentPlanID { + sub.PlanID = currentPlanID updateNeeded = true } - // update sub change if it's changed - if stripeSubscription.Schedule != nil && - stripeSubscription.Schedule.CurrentPhase != nil { - if sub.Phase.EffectiveAt.IsZero() { + // update phase if it's changed + if sub.Phase.PlanID != nextPlanID { + sub.Phase.PlanID = nextPlanID + updateNeeded = true + } + if stripeSubscription.Schedule != nil { + if stripeSubscription.Schedule.CurrentPhase == nil && + sub.Phase.EffectiveAt.Unix() > 0 { + sub.Phase.EffectiveAt = time.Unix(0, 0) + updateNeeded = true + } + if stripeSubscription.Schedule.CurrentPhase != nil && + sub.Phase.EffectiveAt.Unix() != stripeSubscription.Schedule.CurrentPhase.EndDate { sub.Phase.EffectiveAt = time.Unix(stripeSubscription.Schedule.CurrentPhase.EndDate, 0) updateNeeded = true } } + // update sub change if it's changed if updateNeeded { if _, err := s.repository.UpdateByID(ctx, sub); err != nil { return err @@ -280,7 +298,13 @@ func (s *Service) Cancel(ctx context.Context, id string, immediate bool) (Subscr return sub, nil } - if immediate { + // check if schedule exists + _, stripeSchedule, err := s.createOrGetSchedule(ctx, sub) + if err != nil { + return sub, err + } + + if immediate || stripeSchedule == nil { stripeSubscription, err := s.stripeClient.Subscriptions.Cancel(sub.ProviderID, &stripe.SubscriptionCancelParams{ Params: stripe.Params{ Context: ctx, @@ -296,12 +320,6 @@ func (s *Service) Cancel(ctx context.Context, id string, immediate bool) (Subscr sub.CanceledAt = time.Unix(stripeSubscription.CanceledAt, 0) } } else { - // check if schedule exists - _, stripeSchedule, err := s.createOrGetSchedule(ctx, sub) - if err != nil { - return sub, err - } - // update schedule to cancel at the end of the current period var currentPhaseItems []*stripe.SubscriptionSchedulePhaseItemParams for _, phase := range stripeSchedule.Phases { @@ -332,6 +350,10 @@ func (s *Service) Cancel(ctx context.Context, id string, immediate bool) (Subscr Currency: stripe.String(currency), StartDate: stripe.Int64(stripeSchedule.CurrentPhase.StartDate), EndDate: endDate, + Metadata: map[string]string{ + "plan_id": sub.PlanID, + "managed_by": "frontier", + }, }, }, EndBehavior: stripe.String(string(stripe.SubscriptionScheduleEndBehaviorCancel)), @@ -352,16 +374,24 @@ func (s *Service) createOrGetSchedule(ctx context.Context, sub Subscription) (*s Params: stripe.Params{ Context: ctx, }, - Expand: []*string{stripe.String("schedule")}, + Expand: []*string{ + stripe.String("schedule"), + stripe.String("schedule.phases.items.price"), + }, }) if err != nil { return nil, nil, fmt.Errorf("failed to get subscription from billing provider: %w", err) } - var stripeSchedule = stripeSubscription.Schedule - if stripeSchedule == nil || stripeScheduleCreateRequired(stripeSchedule) { + if stripeSubscription.Status == stripe.SubscriptionStatusCanceled || + stripeSubscription.Status == stripe.SubscriptionStatusIncomplete || + stripeSubscription.Status == stripe.SubscriptionStatusIncompleteExpired { + return stripeSubscription, nil, nil + } + + if stripeSubscription.Schedule == nil { // no schedule exists, create a new schedule - stripeSchedule, err = s.stripeClient.SubscriptionSchedules.New(&stripe.SubscriptionScheduleParams{ + stripeSubscription.Schedule, err = s.stripeClient.SubscriptionSchedules.New(&stripe.SubscriptionScheduleParams{ Params: stripe.Params{ Context: ctx, }, @@ -371,8 +401,7 @@ func (s *Service) createOrGetSchedule(ctx context.Context, sub Subscription) (*s return nil, nil, fmt.Errorf("failed to create subscription schedule at billing provider: %w", err) } } - stripeSubscription.Schedule = stripeSchedule - return stripeSubscription, stripeSchedule, nil + return stripeSubscription, stripeSubscription.Schedule, nil } func (s *Service) List(ctx context.Context, filter Filter) ([]Subscription, error) { @@ -393,11 +422,12 @@ func (s *Service) List(ctx context.Context, filter Filter) ([]Subscription, erro // UpdateProductQuantity updates the quantity of the product in the subscription // Note: check if we need to handle subscription schedule -func (s *Service) UpdateProductQuantity(ctx context.Context, orgID string, plan plan.Plan, +func (s *Service) UpdateProductQuantity(ctx context.Context, orgID string, currentPlan plan.Plan, stripeSubscription *stripe.Subscription, stripeSchedule *stripe.SubscriptionSchedule) error { - var shouldUpdateSchedule = false + var orgMemberCount int64 = 1 + var err error - currentPhaseItems := s.getCurrentPhaseFromSchedule(stripeSchedule) + // update current subscription currentSubscriptionItems := make([]*stripe.SubscriptionItemsParams, 0, len(stripeSubscription.Items.Data)) for _, item := range stripeSubscription.Items.Data { currentSubscriptionItems = append(currentSubscriptionItems, &stripe.SubscriptionItemsParams{ @@ -408,60 +438,98 @@ func (s *Service) UpdateProductQuantity(ctx context.Context, orgID string, plan }) } - if planFeature, ok := plan.GetUserSeatProduct(); ok { + if planFeature, ok := currentPlan.GetUserSeatProduct(); ok { + var shouldUpdateSubscription = false // get the current quantity - count, err := s.orgService.MemberCount(ctx, orgID) + orgMemberCount, err = s.orgService.MemberCount(ctx, orgID) if err != nil { return fmt.Errorf("failed to get member count: %w", err) } + for _, planProductPrice := range planFeature.Prices { + // check for changes in subscription + for idx, subItemData := range currentSubscriptionItems { + // convert provider price id to system price id and get the product + if planProductPrice.ProviderID == *subItemData.Price { + shouldChangeQuantity, err := s.shouldChangeScheduleQuantity(orgMemberCount, subItemData) + if err != nil { + return err + } + if shouldChangeQuantity { + shouldUpdateSubscription = true + currentSubscriptionItems[idx].Quantity = &orgMemberCount + } + } + } + } + + if shouldUpdateSubscription { + _, err := s.stripeClient.Subscriptions.Update(stripeSubscription.ID, &stripe.SubscriptionParams{ + Params: stripe.Params{ + Context: ctx, + }, + Items: currentSubscriptionItems, + PendingInvoiceItemInterval: getPendingInvoiceItemInterval(currentPlan), + }) + if err != nil { + return fmt.Errorf("failed to update subscription quantity at billing provider: %w", err) + } + } + } + + // if there is a next phase, we will also update all phases of schedule + currentPhase, nextPhase := s.getCurrentAndNextPhaseFromSchedule(stripeSchedule) + if nextPhase == nil { + // no need to update the phases if there is no next phase + return nil + } + + _, nextPlanID, err := s.getPlanFromSchedule(ctx, stripeSchedule) + if errors.Is(err, ErrNoPhaseActive) { + return nil + } + if err != nil { + return fmt.Errorf("failed to get plan from schedule: %w", err) + } + nextPlan, err := s.planService.GetByID(ctx, nextPlanID) + if err != nil { + return fmt.Errorf("failed to get next plan: %w", err) + } + var shouldUpdateSchedule = false + + if planFeature, ok := currentPlan.GetUserSeatProduct(); ok { for _, planProductPrice := range planFeature.Prices { // check for changes in schedule - for idx, subItemData := range currentPhaseItems { - // convert provider price id to system price id and get the feature + for idx, subItemData := range currentPhase.Items { + // convert provider price id to system price id and get the product if planProductPrice.ProviderID == *subItemData.Price { - shouldChangeQuantity := false - switch strings.ToLower(s.config.ProductConfig.SeatChangeBehavior) { - case "exact": - if count != *subItemData.Quantity { - shouldChangeQuantity = true - } - case "incremental": - if count > *subItemData.Quantity { - shouldChangeQuantity = true - } - default: - return fmt.Errorf("invalid seat change behavior: %s", s.config.ProductConfig.SeatChangeBehavior) + shouldChangeQuantity, err := s.shouldChangePhaseQuantity(orgMemberCount, subItemData) + if err != nil { + return err } if shouldChangeQuantity { shouldUpdateSchedule = true - currentPhaseItems[idx].Quantity = &count + currentPhase.Items[idx].Quantity = &orgMemberCount } } } - - // check for changes in subscription - for idx, subItemData := range currentSubscriptionItems { - // convert provider price id to system price id and get the feature + } + } + if planFeature, ok := nextPlan.GetUserSeatProduct(); ok { + for _, planProductPrice := range planFeature.Prices { + // check for changes in schedule + for idx, subItemData := range nextPhase.Items { + // convert provider price id to system price id and get the product if planProductPrice.ProviderID == *subItemData.Price { - shouldChangeQuantity := false - switch strings.ToLower(s.config.ProductConfig.SeatChangeBehavior) { - case "exact": - if count != *subItemData.Quantity { - shouldChangeQuantity = true - } - case "incremental": - if count > *subItemData.Quantity { - shouldChangeQuantity = true - } - default: - return fmt.Errorf("invalid seat change behavior: %s", s.config.ProductConfig.SeatChangeBehavior) + shouldChangeQuantity, err := s.shouldChangePhaseQuantity(orgMemberCount, subItemData) + if err != nil { + return err } if shouldChangeQuantity { shouldUpdateSchedule = true - currentSubscriptionItems[idx].Quantity = &count + nextPhase.Items[idx].Quantity = &orgMemberCount } } } @@ -469,58 +537,57 @@ func (s *Service) UpdateProductQuantity(ctx context.Context, orgID string, plan } if shouldUpdateSchedule { - _, err := s.stripeClient.Subscriptions.Update(stripeSubscription.ID, &stripe.SubscriptionParams{ + _, err = s.stripeClient.SubscriptionSchedules.Update(stripeSchedule.ID, &stripe.SubscriptionScheduleParams{ Params: stripe.Params{ Context: ctx, }, - Items: currentSubscriptionItems, - PendingInvoiceItemInterval: getPendingInvoiceItemInterval(plan), + Phases: []*stripe.SubscriptionSchedulePhaseParams{ + currentPhase, + nextPhase, + }, }) if err != nil { - return fmt.Errorf("failed to update subscription quantity at billing provider: %w", err) - } - - // TODO(kushsharma): check if we need to update the schedule as well - // get all phases of schedule else they will be overwritten/removed - //allPhases := make([]*stripe.SubscriptionSchedulePhaseParams, 0, len(stripeSchedule.Phases)) - //for _, phase := range stripeSchedule.Phases { - // phaseItems := make([]*stripe.SubscriptionSchedulePhaseItemParams, 0, len(phase.Items)) - // if phase.StartDate == stripeSchedule.CurrentPhase.StartDate && - // phase.EndDate == stripeSchedule.CurrentPhase.EndDate { - // phaseItems = currentPhaseItems - // } else { - // for _, item := range phase.Items { - // phaseItems = append(phaseItems, &stripe.SubscriptionSchedulePhaseItemParams{ - // Price: stripe.String(item.Price.ID), - // Quantity: stripe.Int64(item.Quantity), - // Metadata: item.Metadata, - // }) - // } - // } - // - // allPhases = append(allPhases, &stripe.SubscriptionSchedulePhaseParams{ - // Items: phaseItems, - // Currency: stripe.String(string(phase.Currency)), - // StartDate: stripe.Int64(phase.StartDate), - // EndDate: stripe.Int64(phase.EndDate), - // TrialEnd: stripe.Int64(phase.TrialEnd), - // Metadata: phase.Metadata, - // }) - //} - //_, err = s.stripeClient.SubscriptionSchedules.Update(stripeSchedule.ID, &stripe.SubscriptionScheduleParams{ - // Params: stripe.Params{ - // Context: ctx, - // }, - // Phases: allPhases, - //}) - //if err != nil { - // return fmt.Errorf("failed to update subscription schedule at billing provider: %w", err) - //} + return fmt.Errorf("failed to update subscription schedule at billing provider: %w", err) + } } return nil } +func (s *Service) shouldChangeScheduleQuantity(orgMemberCount int64, subItemData *stripe.SubscriptionItemsParams) (bool, error) { + shouldChangeQuantity := false + switch strings.ToLower(s.config.ProductConfig.SeatChangeBehavior) { + case "exact": + if orgMemberCount != *subItemData.Quantity { + shouldChangeQuantity = true + } + case "incremental": + if orgMemberCount > *subItemData.Quantity { + shouldChangeQuantity = true + } + default: + return false, fmt.Errorf("invalid seat change behavior: %s", s.config.ProductConfig.SeatChangeBehavior) + } + return shouldChangeQuantity, nil +} + +func (s *Service) shouldChangePhaseQuantity(orgMemberCount int64, subItemData *stripe.SubscriptionSchedulePhaseItemParams) (bool, error) { + shouldChangeQuantity := false + switch strings.ToLower(s.config.ProductConfig.SeatChangeBehavior) { + case "exact": + if orgMemberCount != *subItemData.Quantity { + shouldChangeQuantity = true + } + case "incremental": + if orgMemberCount > *subItemData.Quantity { + shouldChangeQuantity = true + } + default: + return false, fmt.Errorf("invalid seat change behavior: %s", s.config.ProductConfig.SeatChangeBehavior) + } + return shouldChangeQuantity, nil +} + // ChangePlan changes the plan of the subscription by creating a subscription schedule // it first checks if the schedule is already created, if not it creates a new schedule // using the current subscription as the base and the new plan as the target in upcoming phase. @@ -640,11 +707,19 @@ func (s *Service) ChangePlan(ctx context.Context, id string, changeRequest Chang StartDate: stripe.Int64(stripeSchedule.CurrentPhase.StartDate), EndDate: endDate, EndDateNow: endDateNow, + Metadata: map[string]string{ + "plan_id": planByStripeSubscription.ID, + "managed_by": "frontier", + }, }, { Items: nextPhaseItems, Currency: stripe.String(currency), Iterations: stripe.Int64(1), + Metadata: map[string]string{ + "plan_id": planObj.ID, + "managed_by": "frontier", + }, }, }, EndBehavior: stripe.String("release"), @@ -657,8 +732,13 @@ func (s *Service) ChangePlan(ctx context.Context, id string, changeRequest Chang return change, fmt.Errorf("failed to update subscription schedule at billing provider: %w", err) } - sub.Phase.EffectiveAt = time.Unix(updatedSchedule.Phases[1].StartDate, 0) - sub.Phase.PlanID = planObj.ID + // update subscription with new phase + _, nextPlanID, err := s.getPlanFromSchedule(ctx, updatedSchedule) + if err != nil { + return change, err + } + sub.Phase.EffectiveAt = time.Unix(updatedSchedule.CurrentPhase.EndDate, 0) + sub.Phase.PlanID = nextPlanID sub, err = s.repository.UpdateByID(ctx, sub) if err != nil { return change, err @@ -685,6 +765,100 @@ func (s *Service) getCurrentPhaseFromSchedule(stripeSchedule *stripe.Subscriptio return currentPhaseItems } +func (s *Service) getCurrentAndNextPhaseFromSchedule(stripeSchedule *stripe.SubscriptionSchedule) (*stripe.SubscriptionSchedulePhaseParams, *stripe.SubscriptionSchedulePhaseParams) { + if stripeSchedule == nil || stripeSchedule.CurrentPhase == nil { + return nil, nil + } + var currentPhaseItems []*stripe.SubscriptionSchedulePhaseItemParams + var nextPhaseItems []*stripe.SubscriptionSchedulePhaseItemParams + var currentPhase *stripe.SubscriptionSchedulePhaseParams + var nextPhase *stripe.SubscriptionSchedulePhaseParams + + for _, phase := range stripeSchedule.Phases { + if phase.StartDate == stripeSchedule.CurrentPhase.StartDate { + currentPhaseItems = make([]*stripe.SubscriptionSchedulePhaseItemParams, 0, len(phase.Items)) + for _, item := range phase.Items { + currentPhaseItems = append(currentPhaseItems, &stripe.SubscriptionSchedulePhaseItemParams{ + Price: stripe.String(item.Price.ID), + Quantity: stripe.Int64(item.Quantity), + Metadata: item.Metadata, + }) + } + + currentPhase = &stripe.SubscriptionSchedulePhaseParams{ + Items: currentPhaseItems, + Currency: stripe.String(string(phase.Currency)), + StartDate: stripe.Int64(phase.StartDate), + EndDate: stripe.Int64(phase.EndDate), + Metadata: phase.Metadata, + } + if phase.TrialEnd > 0 { + currentPhase.TrialEnd = stripe.Int64(phase.TrialEnd) + } + } else if phase.StartDate >= stripeSchedule.CurrentPhase.EndDate { + nextPhaseItems = make([]*stripe.SubscriptionSchedulePhaseItemParams, 0, len(phase.Items)) + for _, item := range phase.Items { + nextPhaseItems = append(nextPhaseItems, &stripe.SubscriptionSchedulePhaseItemParams{ + Price: stripe.String(item.Price.ID), + Quantity: stripe.Int64(item.Quantity), + Metadata: item.Metadata, + }) + } + + nextPhase = &stripe.SubscriptionSchedulePhaseParams{ + Items: nextPhaseItems, + Currency: stripe.String(string(phase.Currency)), + StartDate: stripe.Int64(phase.StartDate), + EndDate: stripe.Int64(phase.EndDate), + Metadata: phase.Metadata, + } + if phase.TrialEnd > 0 { + nextPhase.TrialEnd = stripe.Int64(phase.TrialEnd) + } + } + } + + return currentPhase, nextPhase +} + +// todo(kushsharma): return plan instead of id +func (s *Service) getPlanFromSchedule(ctx context.Context, stripeSchedule *stripe.SubscriptionSchedule) (string, string, error) { + if stripeSchedule == nil || stripeSchedule.CurrentPhase == nil { + return "", "", ErrNoPhaseActive + } + var currentPlanID string + var nextPlanID string + for _, phase := range stripeSchedule.Phases { + if phase.StartDate == stripeSchedule.CurrentPhase.StartDate { + if phase.Metadata != nil { + if planID, ok := phase.Metadata["plan_id"]; ok { + currentPlanID = planID + continue + } + } + currentPlan, err := s.findPlanByStripePhase(ctx, phase) + if err != nil { + return "", "", err + } + currentPlanID = currentPlan.ID + } else if phase.StartDate >= stripeSchedule.CurrentPhase.EndDate { + if phase.Metadata != nil { + if planID, ok := phase.Metadata["plan_id"]; ok { + nextPlanID = planID + continue + } + } + + nextPlan, err := s.findPlanByStripePhase(ctx, phase) + if err != nil { + return "", "", err + } + nextPlanID = nextPlan.ID + } + } + return currentPlanID, nextPlanID, nil +} + // CancelUpcomingPhase cancels the scheduled phase of the subscription func (s *Service) CancelUpcomingPhase(ctx context.Context, sub Subscription) error { _, stripeSchedule, err := s.createOrGetSchedule(ctx, sub) @@ -714,6 +888,10 @@ func (s *Service) CancelUpcomingPhase(ctx context.Context, sub Subscription) err Currency: stripe.String(currency), StartDate: stripe.Int64(stripeSchedule.CurrentPhase.StartDate), EndDate: stripe.Int64(stripeSchedule.CurrentPhase.EndDate), + Metadata: map[string]string{ + "plan_id": sub.PlanID, + "managed_by": "frontier", + }, }, }, EndBehavior: stripe.String("release"), @@ -736,12 +914,6 @@ func (s *Service) CancelUpcomingPhase(ctx context.Context, sub Subscription) err return nil } -func stripeScheduleCreateRequired(stripeSchedule *stripe.SubscriptionSchedule) bool { - return stripeSchedule != nil && - (stripeSchedule.Status == stripe.SubscriptionScheduleStatusCanceled || - stripeSchedule.Status == stripe.SubscriptionScheduleStatusReleased) -} - func (s *Service) findPlanByStripeSubscription(ctx context.Context, stripeSubscription *stripe.Subscription) (plan.Plan, error) { // keep plan id in sync based on what products are attached to the subscription // it can change if the user changes the plan using a schedule @@ -778,6 +950,42 @@ func (s *Service) findPlanByStripeSubscription(ctx context.Context, stripeSubscr return plans[0], nil } +func (s *Service) findPlanByStripePhase(ctx context.Context, stripePhase *stripe.SubscriptionSchedulePhase) (plan.Plan, error) { + // keep plan id in sync based on what products are attached to the subscription + // it can change if the user changes the plan using a schedule + var productPlanIDs []string + var interval string + + for _, subStripeItem := range stripePhase.Items { + product, err := s.productService.GetByProviderID(ctx, subStripeItem.Price.Product.ID) + if err != nil { + return plan.Plan{}, fmt.Errorf("failed to get product from billing provider: %w", err) + } + if len(productPlanIDs) == 0 { + productPlanIDs = append(productPlanIDs, product.PlanIDs...) + interval = string(subStripeItem.Price.Recurring.Interval) + continue + } + productPlanIDs = utils.Intersection(productPlanIDs, product.PlanIDs) + } + + plans, err := s.planService.List(ctx, plan.Filter{ + IDs: productPlanIDs, + Interval: interval, + }) + if err != nil { + return plan.Plan{}, err + } + + if len(plans) == 0 { + return plan.Plan{}, fmt.Errorf("no plan found for phase products: %v, interval: %s", productPlanIDs, interval) + } else if len(plans) > 1 { + return plan.Plan{}, fmt.Errorf("multiple plans found for products: %v", plans) + } + + return plans[0], nil +} + func (s *Service) ensureCreditsForPlan(ctx context.Context, customerID string, subPlan plan.Plan) error { txID := uuid.NewSHA1(credit.TxNamespaceUUID, []byte(fmt.Sprintf("%s:%s", subPlan.ID, customerID))).String() if subPlan.OnStartCredits == 0 { diff --git a/billing/subscription/subscription.go b/billing/subscription/subscription.go index ec2232fe7..699fd5990 100644 --- a/billing/subscription/subscription.go +++ b/billing/subscription/subscription.go @@ -1,18 +1,19 @@ package subscription import ( - "errors" + "fmt" "time" "github.com/raystack/frontier/pkg/metadata" ) var ( - ErrNotFound = errors.New("subscription not found") - ErrInvalidUUID = errors.New("invalid syntax of uuid") - ErrInvalidID = errors.New("invalid subscription id") - ErrInvalidDetail = errors.New("invalid subscription detail") - ErrAlreadyOnSamePlan = errors.New("already on the same plan") + ErrNotFound = fmt.Errorf("subscription not found") + ErrInvalidUUID = fmt.Errorf("invalid syntax of uuid") + ErrInvalidID = fmt.Errorf("invalid subscription id") + ErrInvalidDetail = fmt.Errorf("invalid subscription detail") + ErrAlreadyOnSamePlan = fmt.Errorf("already on the same plan") + ErrNoPhaseActive = fmt.Errorf("no phase active") ) type State string