Skip to content

Commit

Permalink
Fix copier expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
irees committed Dec 12, 2024
1 parent 95a4501 commit bde025e
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 100 deletions.
140 changes: 86 additions & 54 deletions copier/copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,46 @@ func (copier *Copier) checkEntity(ent tt.Entity) error {
return nil
}

// writerAddEntities writes a batch of entities through copier.Writer and
// records the bookkeeping that follows a successful write.
//
// The filename of the first entity is used as the key for the whole batch
// (assumes every entity in okEnts shares that filename; the caller is
// expected to batch accordingly). After the write:
//   - each source entity ID is mapped to its written ID in copier.EntityMap;
//   - entities implementing tt.EntityWithGroupKey with a non-empty group ID
//     have that ID registered under "<filename>:<groupKey>";
//   - the per-filename entity count in copier.result is increased;
//   - every registered after-writer is invoked for each written entity.
//
// An empty batch is a no-op. The first error from the writer, from a count
// mismatch, or from an after-writer is returned.
func (copier *Copier) writerAddEntities(okEnts []tt.Entity) error {
	count := len(okEnts)
	if count == 0 {
		return nil
	}
	efn := okEnts[0].Filename()

	// Source IDs are captured before the write (presumably because the
	// writer may update entity IDs; confirm against Writer implementations).
	sourceIDs := make([]string, 0, count)
	for _, ent := range okEnts {
		sourceIDs = append(sourceIDs, ent.EntityID())
	}

	writtenIDs, err := copier.Writer.AddEntities(okEnts)
	if err != nil {
		copier.sublogger.Error().Err(err).Str("filename", efn).Msgf("critical error: failed to write %d entities", count)
		return err
	}
	if len(writtenIDs) != count {
		return fmt.Errorf("expected to write %d entities, got %d", count, len(writtenIDs))
	}

	// Record source ID -> written ID, plus any group keys.
	for idx, writtenID := range writtenIDs {
		copier.EntityMap.Set(efn, sourceIDs[idx], writtenID)
		grouped, hasGroup := okEnts[idx].(tt.EntityWithGroupKey)
		if !hasGroup {
			continue
		}
		groupKey, groupId := grouped.GroupKey()
		if groupId == "" {
			continue
		}
		copier.EntityMap.Set(fmt.Sprintf("%s:%s", efn, groupKey), groupId, groupId)
	}
	copier.result.EntityCount[efn] += count

	// AfterWriters: entity-major order, each after-writer in registration order.
	for idx := range writtenIDs {
		for _, aw := range copier.afterWriters {
			if awErr := aw.AfterWrite(writtenIDs[idx], okEnts[idx], copier.EntityMap); awErr != nil {
				return awErr
			}
		}
	}
	return nil
}

//////////////////////////////////
////////// Copy Methods //////////
//////////////////////////////////
Expand Down Expand Up @@ -959,61 +999,40 @@ func copyEntities[T tt.Entity](copier *Copier, ents []T) ([]tt.Entity, error) {
if len(ents) == 0 {
return nil, nil
}
okEnts := make([]tt.Entity, 0, len(ents))
expandedEnts := make([]tt.Entity, 0, len(ents))
for _, ent := range ents {
ent := ent
expanded := false
for _, f := range copier.expandFilters {
if exp, ok, err := f.Expand(ent, copier.EntityMap); err != nil {
// skip
if a, ok, err := f.Expand(ent, copier.EntityMap); err != nil {
log.Error().Err(err).Msg("failed to expand")
} else if ok {
expanded = true
if err := copier.checkEntity(ent); err == nil {
okEnts = append(okEnts, exp...)
}
expandedEnts = append(expandedEnts, a...)
}
}
if !expanded {
if err := copier.checkEntity(ent); err == nil {
okEnts = append(okEnts, ent)
}
expandedEnts = append(expandedEnts, ent)
}
}
if len(okEnts) == 0 {
return nil, nil
// Group by filename, retaining input order
batchedEnts := batchEntFilenames(expandedEnts)
if len(batchedEnts) == 0 {
batchedEnts = append(batchedEnts, expandedEnts)
}
efn := okEnts[0].Filename()
sids := make([]string, len(okEnts))
for i, ent := range okEnts {
sids[i] = ent.EntityID()
}
eids, err := copier.Writer.AddEntities(okEnts)
if err != nil {
copier.sublogger.Error().Err(err).Str("filename", efn).Msgf("critical error: failed to write %d entities", len(okEnts))
return nil, err
}
if len(eids) != len(okEnts) {
return nil, fmt.Errorf("expected to write %d entities, got %d", len(okEnts), len(eids))
}
for i, ent := range okEnts {
sid := sids[i]
eid := eids[i]
copier.EntityMap.Set(efn, sid, eid)
if entExt, ok := ent.(tt.EntityWithGroupKey); ok {
if groupKey, groupId := entExt.GroupKey(); groupId != "" {
copier.EntityMap.Set(fmt.Sprintf("%s:%s", efn, groupKey), groupId, groupId)
// Write in filename batches
okEnts := make([]tt.Entity, 0, len(expandedEnts))
for _, batch := range batchedEnts {
checkedEnts := make([]tt.Entity, 0, len(batch))
for _, ent := range batch {
if err := copier.checkEntity(ent); err == nil {
checkedEnts = append(checkedEnts, ent)
}
}
}
copier.result.EntityCount[efn] += len(okEnts)

// AfterWriters
for i, eid := range eids {
for _, v := range copier.afterWriters {
if err := v.AfterWrite(eid, okEnts[i], copier.EntityMap); err != nil {
return nil, err
}
if err := copier.writerAddEntities(checkedEnts); err != nil {
return nil, err
}
okEnts = append(okEnts, checkedEnts...)
}
return okEnts, nil
}
Expand Down Expand Up @@ -1065,23 +1084,36 @@ func batchChan[T any](it chan T, batchSize int, filt func(T) bool) iter.Seq[[]T]
}
}

func batchMap[K comparable, T any](m map[K]T, batchSize int) iter.Seq[[]T] {
return func(yield func([]T) bool) {
var ents []T
for _, v := range m {
ents = append(ents, v)
if len(ents) < batchSize {
continue
}
if !yield(ents) {
return
}
ents = nil
func batchEntFilenames(ents []tt.Entity) [][]tt.Entity {
mixedFns := false
lastFn := ents[0].Filename()
for _, ent := range ents {
fn := ent.Filename()
if fn != lastFn {
mixedFns = true
break
}
if len(ents) > 0 {
yield(ents)
}
if !mixedFns {
return nil
}
var batches [][]tt.Entity
var batch []tt.Entity
lastFn = ents[0].Filename()
for _, ent := range ents {
if fn := ent.Filename(); fn == lastFn {
batch = append(batch, ent)
} else {
lastFn = fn
batches = append(batches, batch)
batch = nil
batch = append(batch, ent)
}
}
if len(batch) > 0 {
batches = append(batches, batch)
}
return batches
}

func shapeLines(it chan []gtfs.Shape) chan service.ShapeLine {
Expand Down
86 changes: 40 additions & 46 deletions tlcsv/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,22 @@ func (writer *Writer) NewReader() (adapters.Reader, error) {
return NewReader(writer.WriterAdapter.Path())
}

// AddEntity writes a single entity to the output by delegating to
// AddEntities with a one-element batch. It returns the first ID reported
// by AddEntities, or an error if the write failed or no ID came back.
func (writer *Writer) AddEntity(ent tt.Entity) (string, error) {
	batch := []tt.Entity{ent}
	eids, err := writer.AddEntities(batch)
	switch {
	case err != nil:
		return "", err
	case len(eids) == 0:
		return "", errors.New("did not write expected number of entities")
	}
	return eids[0], nil
}

// AddEntities writes entities to the output.
func (writer *Writer) AddEntities(ents []tt.Entity) ([]string, error) {
eids := []string{}
if len(ents) == 0 {
return eids, nil
}

// Normal write path
ent := ents[0]
efn := ents[0].Filename()
for _, ent := range ents {
if efn != ent.Filename() {
return eids, errors.New("all entities must be same type")
}
// Horrible special case bug fix
if v, ok := ent.(*gtfs.Stop); ok {
c := v.Coordinates()
v.StopLon.Set(c[0])
v.StopLat.Set(c[1])
}
return nil, nil
}

// Awful Ugly Hack to Flatten entities
Expand All @@ -96,19 +92,43 @@ func (writer *Writer) AddEntities(ents []tt.Entity) ([]string, error) {
}
}
if len(expandedEnts) > 0 {
_, err := writer.AddEntities(expandedEnts)
_, err := writer.addBatch(expandedEnts)
if err != nil {
return nil, err
}
return originalEids, nil
}

// Normal write path
return writer.addBatch(ents)
}

func (writer *Writer) addBatch(ents []tt.Entity) ([]string, error) {
var eids []string
if len(ents) == 0 {
return nil, nil
}

ent := ents[0]
efn := ents[0].Filename()
for _, ent := range ents {
if efn != ent.Filename() {
return nil, errors.New("all entities must be same type")
}
// Horrible special case bug fix
if v, ok := ent.(*gtfs.Stop); ok {
c := v.Coordinates()
v.StopLon.Set(c[0])
v.StopLat.Set(c[1])
}
}

extraHeader := writer.extraHeaders[efn]
header, ok := writer.headers[efn]
if !ok {
h, err := dumpHeader(ent)
if err != nil {
return eids, err
return nil, err
}
header = h
if extEnt, ok2 := ent.(tt.EntityWithExtra); ok2 && writer.writeExtraColumns {
Expand All @@ -128,7 +148,7 @@ func (writer *Writer) AddEntities(ents []tt.Entity) ([]string, error) {
}
row, err := dumpRow(ent, header)
if err != nil {
return eids, err
return nil, err
}
if len(extraHeader) > 0 {
if extEnt, ok := ent.(tt.EntityWithExtra); ok {
Expand All @@ -146,32 +166,6 @@ func (writer *Writer) AddEntities(ents []tt.Entity) ([]string, error) {
return eids, err
}

// AddEntity writes an entity to the output.
// It wraps the entity in a one-element slice, passes it to AddEntities, and
// returns the first resulting ID. An error is returned if AddEntities fails
// or reports no IDs.
func (writer *Writer) AddEntity(ent tt.Entity) (string, error) {
// Disabled earlier variant, kept for reference: it called Flatten() on
// entities implementing canFlatten before handing them to AddEntities.
// var eids []string
// var err error
// if v, ok := ent.(canFlatten); ok {
// eids, err = writer.AddEntities(v.Flatten())
// } else {
// eids, err = writer.AddEntities([]tt.Entity{ent})
// }
// if err != nil {
// return "", err
// }
// if len(eids) == 0 {
// return "", errors.New("did not write expected number of entities")
// }
// return eids[0], nil
eids, err := writer.AddEntities([]tt.Entity{ent})
if err != nil {
return "", err
}
// Guard against a writer that reports success but returns no IDs.
if len(eids) == 0 {
return "", errors.New("did not write expected number of entities")
}
return eids[0], nil
}

// canFlatten is implemented by entities that can produce a list of
// entities via Flatten.
// NOTE(review): presumably consumed by the flatten step in AddEntities;
// that code is not fully visible here — confirm.
type canFlatten interface {
Flatten() []tt.Entity
}

0 comments on commit bde025e

Please sign in to comment.