[Kernel] Update kernel/examples to work with latest API changes
## Description
The Kernel APIs have changed since `kernel/examples` was last updated. This patch fixes the examples to work with the latest API changes.
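
The gist of the migration, as a rough sketch of the updated read path (not code from this patch; it assumes a `DefaultTableClient` built from a Hadoop `Configuration`, a placeholder table path, and the `Snapshot.getScanBuilder(...)` API of this Kernel version):

```java
import java.util.Optional;

import org.apache.hadoop.conf.Configuration;

import io.delta.kernel.Scan;
import io.delta.kernel.Snapshot;
import io.delta.kernel.Table;
import io.delta.kernel.client.TableClient;
import io.delta.kernel.data.FilteredColumnarBatch;
import io.delta.kernel.data.Row;
import io.delta.kernel.defaults.client.DefaultTableClient;
import io.delta.kernel.utils.CloseableIterator;

public class ReadSketch {
    public static void main(String[] args) throws Exception {
        String myTablePath = args[0]; // placeholder: path to a Delta table
        TableClient tableClient = DefaultTableClient.create(new Configuration());

        // Table.forPath(...) now takes the TableClient as its first argument.
        Table table = Table.forPath(tableClient, myTablePath);
        Snapshot snapshot = table.getLatestSnapshot(tableClient);
        Scan scan = snapshot.getScanBuilder(tableClient).build();

        Row scanState = scan.getScanState(tableClient);
        // getScanFiles(...) and Scan.readData(...) now yield FilteredColumnarBatch
        // (data plus selection vector) instead of the removed DataReadResult.
        try (CloseableIterator<FilteredColumnarBatch> scanFiles = scan.getScanFiles(tableClient)) {
            while (scanFiles.hasNext()) {
                try (CloseableIterator<FilteredColumnarBatch> data = Scan.readData(
                        tableClient,
                        scanState,
                        scanFiles.next().getRows(),
                        Optional.empty())) {
                    while (data.hasNext()) {
                        // getRows() already applies the selection vector, so the
                        // per-row selection-vector check is gone.
                        try (CloseableIterator<Row> rows = data.next().getRows()) {
                            while (rows.hasNext()) {
                                Row row = rows.next(); // read values via row.getInt(ordinal), etc.
                            }
                        }
                    }
                }
            }
        }
    }
}
```

The diffs below apply this same migration to the base, single-threaded, and multi-threaded example readers.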

## How was this patch tested?
Manually ran `<repo-root>/kernel/examples/run-kernel-examples.py` and verified that the build succeeds and the displayed results are correct.
vkorukanti authored and xupefei committed Oct 31, 2023
1 parent 76be591 commit c21d491
Showing 5 changed files with 99 additions and 56 deletions.
kernel/examples/run-kernel-examples.py: file mode changed 100644 → 100755 (no content changes).
BaseTableReader.java
@@ -16,6 +16,9 @@
 package io.delta.kernel.examples;
 
 import java.io.IOException;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.ZoneOffset;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
@@ -35,14 +38,12 @@
 
 import io.delta.kernel.TableNotFoundException;
 import io.delta.kernel.client.TableClient;
-import io.delta.kernel.data.ColumnVector;
-import io.delta.kernel.data.ColumnarBatch;
-import io.delta.kernel.data.DataReadResult;
-import io.delta.kernel.types.StructField;
-import io.delta.kernel.types.StructType;
+import io.delta.kernel.data.FilteredColumnarBatch;
+import io.delta.kernel.data.Row;
+import io.delta.kernel.types.*;
 import io.delta.kernel.utils.CloseableIterator;
 
 import io.delta.kernel.defaults.client.DefaultTableClient;
-import io.delta.kernel.defaults.internal.data.vector.VectorUtils;
 
 /**
  * Base class for reading Delta Lake tables using the Delta Kernel APIs.
@@ -88,18 +89,18 @@ protected static StructType pruneSchema(StructType baseSchema, Optional<List<String>> columnsOpt)
         return new StructType(selectedFields);
     }
 
-    protected static int printData(DataReadResult dataReadResult, int maxRowsToPrint) {
+    protected static int printData(FilteredColumnarBatch data, int maxRowsToPrint) {
         int printedRowCount = 0;
-        ColumnarBatch data = dataReadResult.getData();
-        Optional<ColumnVector> selectionVector = dataReadResult.getSelectionVector();
-        for (int rowId = 0; rowId < data.getSize(); rowId++) {
-            if (!selectionVector.isPresent() || selectionVector.get().getBoolean(rowId)) {
-                printRow(data, rowId);
+        try (CloseableIterator<Row> rows = data.getRows()) {
+            while (rows.hasNext()) {
+                printRow(rows.next());
                 printedRowCount++;
                 if (printedRowCount == maxRowsToPrint) {
                     break;
                 }
             }
+        } catch (Exception e) {
+            throw new RuntimeException(e);
         }
         return printedRowCount;
     }
@@ -108,12 +109,11 @@ protected static void printSchema(StructType schema) {
         System.out.printf(formatter(schema.length()), schema.fieldNames().toArray(new String[0]));
     }
 
-    protected static void printRow(ColumnarBatch batch, int rowId) {
-        int numCols = batch.getSchema().length();
-        Object[] rowValues = IntStream.range(0, numCols).mapToObj(colOrdinal -> {
-            ColumnVector columnVector = batch.getColumnVector(colOrdinal);
-            return VectorUtils.getValueAsObject(columnVector, rowId);
-        }).toArray();
+    protected static void printRow(Row row) {
+        int numCols = row.getSchema().length();
+        Object[] rowValues = IntStream.range(0, numCols)
+            .mapToObj(colOrdinal -> getValue(row, colOrdinal))
+            .toArray();
 
         // TODO: Need to handle the Row, Map, Array, Timestamp, Date types specially to
         // print them in the format they need. Copy this code from Spark CLI.
@@ -178,5 +178,52 @@ private static String formatter(int length) {
             .mapToObj(i -> "%20s")
             .collect(Collectors.joining("|")) + "\n";
     }
+
+    private static String getValue(Row row, int columnOrdinal) {
+        DataType dataType = row.getSchema().at(columnOrdinal).getDataType();
+        if (row.isNullAt(columnOrdinal)) {
+            return null;
+        } else if (dataType instanceof BooleanType) {
+            return Boolean.toString(row.getBoolean(columnOrdinal));
+        } else if (dataType instanceof ByteType) {
+            return Byte.toString(row.getByte(columnOrdinal));
+        } else if (dataType instanceof ShortType) {
+            return Short.toString(row.getShort(columnOrdinal));
+        } else if (dataType instanceof IntegerType) {
+            return Integer.toString(row.getInt(columnOrdinal));
+        } else if (dataType instanceof DateType) {
+            // DateType data is stored internally as the number of days since 1970-01-01.
+            int daysSinceEpochUTC = row.getInt(columnOrdinal);
+            return LocalDate.ofEpochDay(daysSinceEpochUTC).toString();
+        } else if (dataType instanceof LongType) {
+            return Long.toString(row.getLong(columnOrdinal));
+        } else if (dataType instanceof TimestampType) {
+            // TimestampType data is stored internally as the number of microseconds since epoch.
+            long microSecsSinceEpochUTC = row.getLong(columnOrdinal);
+            LocalDateTime dateTime = LocalDateTime.ofEpochSecond(
+                microSecsSinceEpochUTC / 1_000_000 /* epochSecond */,
+                (int) (microSecsSinceEpochUTC % 1_000_000) * 1_000 /* nanoOfSecond */,
+                ZoneOffset.UTC);
+            return dateTime.toString();
+        } else if (dataType instanceof FloatType) {
+            return Float.toString(row.getFloat(columnOrdinal));
+        } else if (dataType instanceof DoubleType) {
+            return Double.toString(row.getDouble(columnOrdinal));
+        } else if (dataType instanceof StringType) {
+            return row.getString(columnOrdinal);
+        } else if (dataType instanceof BinaryType) {
+            return new String(row.getBinary(columnOrdinal));
+        } else if (dataType instanceof DecimalType) {
+            return row.getDecimal(columnOrdinal).toString();
+        } else if (dataType instanceof StructType) {
+            return "TODO: struct value";
+        } else if (dataType instanceof ArrayType) {
+            return "TODO: list value";
+        } else if (dataType instanceof MapType) {
+            return "TODO: map value";
+        } else {
+            throw new UnsupportedOperationException("unsupported data type: " + dataType);
+        }
+    }
 }

MultiThreadedTableReader.java
@@ -15,8 +15,6 @@
  */
 package io.delta.kernel.examples;
 
-import java.io.IOException;
-import java.io.UncheckedIOException;
 import java.util.List;
 import java.util.Optional;
 import java.util.concurrent.ArrayBlockingQueue;
@@ -25,6 +23,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Option;
@@ -35,8 +34,7 @@
 import io.delta.kernel.Table;
 import io.delta.kernel.TableNotFoundException;
 import io.delta.kernel.client.TableClient;
-import io.delta.kernel.data.ColumnarBatch;
-import io.delta.kernel.data.DataReadResult;
+import io.delta.kernel.data.FilteredColumnarBatch;
 import io.delta.kernel.data.Row;
 import io.delta.kernel.examples.utils.RowSerDe;
 import io.delta.kernel.types.StructType;
@@ -80,7 +78,7 @@ public MultiThreadedTableReader(int numThreads, String tablePath) {
 
     public void show(int limit, Optional<List<String>> columnsOpt)
         throws TableNotFoundException {
-        Table table = Table.forPath(tablePath);
+        Table table = Table.forPath(tableClient, tablePath);
         Snapshot snapshot = table.getLatestSnapshot(tableClient);
         StructType readSchema = pruneSchema(snapshot.getSchema(tableClient), columnsOpt);
 
@@ -155,6 +153,7 @@ private class Reader {
         private final BlockingQueue<ScanFile> workQueue = new ArrayBlockingQueue<>(20);
 
         private int readRecordCount; // Data read so far.
+        private AtomicReference<Exception> error = new AtomicReference<>();
 
         Reader(int limit) {
             this.limit = limit;
@@ -185,31 +184,35 @@ void readData(StructType readSchema, Snapshot snapshot) {
             } finally {
                 stopSignal.set(true);
                 executorService.shutdownNow();
+                if (error.get() != null) {
+                    throw new RuntimeException(error.get());
+                }
             }
         }
 
         private Runnable workGenerator(Scan scan) {
             return (() -> {
                 try {
-                    Row scanStateRow = scan.getScanState(tableClient);
-                    CloseableIterator<ColumnarBatch> scanFileIter = scan.getScanFiles(tableClient);
+                    Row scanStateRow = scan.getScanState(tableClient);
+                    try (CloseableIterator<FilteredColumnarBatch> scanFileIter =
+                            scan.getScanFiles(tableClient)) {
 
                     while (scanFileIter.hasNext() && !stopSignal.get()) {
-                        ColumnarBatch scanFileBatch = scanFileIter.next();
-                        try (CloseableIterator<Row> scanFileRows = scanFileBatch.getRows()) {
+                        try (CloseableIterator<Row> scanFileRows = scanFileIter.next().getRows()) {
                             while (scanFileRows.hasNext() && !stopSignal.get()) {
                                 workQueue.put(new ScanFile(scanStateRow, scanFileRows.next()));
                             }
-                        } catch (IOException ioe) {
-                            throw new RuntimeException(ioe);
                         }
                     }
 
                     for (int i = 0; i < numThreads; i++) {
                         // poison pill for each worker thread to stop the work.
                         workQueue.put(ScanFile.POISON_PILL);
                     }
+                    }
                 } catch (InterruptedException ie) {
                     System.out.print("Work generator is interrupted");
+                } catch (Exception e) {
+                    error.compareAndSet(null /* expected */, e);
+                    throw new RuntimeException(e);
                 }
             });
         }
@@ -221,7 +224,7 @@ private Runnable workConsumer(int workerId) {
                     if (work == ScanFile.POISON_PILL) {
                         return; // exit as there are no more work units
                     }
-                    try (CloseableIterator<DataReadResult> dataIter = Scan.readData(
+                    try (CloseableIterator<FilteredColumnarBatch> dataIter = Scan.readData(
                         tableClient,
                         work.getScanRow(tableClient),
                         Utils.singletonCloseableIterator(work.getScanFileRow(tableClient)),
@@ -233,10 +236,11 @@
                             }
                         }
                     }
-                } catch (IOException ioe) {
-                    throw new UncheckedIOException(ioe);
-                } catch (InterruptedException ie) {
-                    System.out.printf("Worker %d is interrupted." + workerId);
+                } catch (Exception e) {
+                    error.compareAndSet(null /* expected */, e);
+                    throw new RuntimeException(e);
                 } finally {
                     countDownLatch.countDown();
                 }
@@ -246,12 +250,12 @@ private Runnable workConsumer(int workerId) {
         /**
         * Returns true when a sufficient number of rows have been received
         */
-        private boolean printDataBatch(DataReadResult dataReadResult) {
+        private boolean printDataBatch(FilteredColumnarBatch data) {
            synchronized (this) {
                if (readRecordCount >= limit) {
                    return true;
                }
-               readRecordCount += printData(dataReadResult, limit - readRecordCount);
+               readRecordCount += printData(data, limit - readRecordCount);
                return readRecordCount >= limit;
            }
        }
SingleThreadedTableReader.java
@@ -25,8 +25,7 @@
 import io.delta.kernel.Snapshot;
 import io.delta.kernel.Table;
 import io.delta.kernel.TableNotFoundException;
-import io.delta.kernel.data.ColumnarBatch;
-import io.delta.kernel.data.DataReadResult;
+import io.delta.kernel.data.FilteredColumnarBatch;
 import io.delta.kernel.data.Row;
 import io.delta.kernel.types.StructType;
 import io.delta.kernel.utils.CloseableIterator;
@@ -53,7 +52,7 @@ public SingleThreadedTableReader(String tablePath) {
     @Override
     public void show(int limit, Optional<List<String>> columnsOpt)
         throws TableNotFoundException, IOException {
-        Table table = Table.forPath(tablePath);
+        Table table = Table.forPath(tableClient, tablePath);
         Snapshot snapshot = table.getLatestSnapshot(tableClient);
         StructType readSchema = pruneSchema(snapshot.getSchema(tableClient), columnsOpt);
 
@@ -93,19 +92,19 @@ private void readData(
         printSchema(readSchema);
 
         Row scanState = scan.getScanState(tableClient);
-        CloseableIterator<ColumnarBatch> scanFileIter = scan.getScanFiles(tableClient);
+        CloseableIterator<FilteredColumnarBatch> scanFileIter = scan.getScanFiles(tableClient);
 
         int readRecordCount = 0;
         try {
             while (scanFileIter.hasNext()) {
-                try (CloseableIterator<DataReadResult> data =
+                try (CloseableIterator<FilteredColumnarBatch> data =
                         Scan.readData(
                             tableClient,
                             scanState,
                             scanFileIter.next().getRows(),
                             Optional.empty())) {
                     while (data.hasNext()) {
-                        DataReadResult dataReadResult = data.next();
+                        FilteredColumnarBatch dataReadResult = data.next();
                         readRecordCount += printData(dataReadResult, maxRowCount - readRecordCount);
                         if (readRecordCount >= maxRowCount) {
                             return;
RowSerDe.java
@@ -26,19 +26,8 @@
 
 import io.delta.kernel.client.TableClient;
 import io.delta.kernel.data.Row;
-import io.delta.kernel.types.ArrayType;
-import io.delta.kernel.types.BooleanType;
-import io.delta.kernel.types.ByteType;
-import io.delta.kernel.types.DataType;
-import io.delta.kernel.types.DoubleType;
-import io.delta.kernel.types.FloatType;
-import io.delta.kernel.types.IntegerType;
-import io.delta.kernel.types.LongType;
-import io.delta.kernel.types.MapType;
-import io.delta.kernel.types.ShortType;
-import io.delta.kernel.types.StringType;
-import io.delta.kernel.types.StructField;
-import io.delta.kernel.types.StructType;
+import io.delta.kernel.types.*;
+import io.delta.kernel.utils.VectorUtils;
 
 import io.delta.kernel.internal.types.TableSchemaSerDe;
 
@@ -111,12 +100,16 @@ private static Map<String, Object> convertRowToJsonObject(Row row) {
             value = row.getFloat(fieldId);
         } else if (fieldType instanceof DoubleType) {
             value = row.getDouble(fieldId);
+        } else if (fieldType instanceof DateType) {
+            value = row.getInt(fieldId);
+        } else if (fieldType instanceof TimestampType) {
+            value = row.getLong(fieldId);
         } else if (fieldType instanceof StringType) {
             value = row.getString(fieldId);
         } else if (fieldType instanceof ArrayType) {
-            value = row.getArray(fieldId);
+            value = VectorUtils.toJavaList(row.getArray(fieldId));
         } else if (fieldType instanceof MapType) {
-            value = row.getMap(fieldId);
+            value = VectorUtils.toJavaMap(row.getMap(fieldId));
         } else if (fieldType instanceof StructType) {
             Row subRow = row.getStruct(fieldId);
             value = convertRowToJsonObject(subRow);
