Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added shapley values to rest scorer endpoint #241

Merged
merged 21 commits into from
Aug 27, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public ResponseEntity<String> getModelId() {

@Override
@RequestMapping("/invocations")
public ResponseEntity<ScoreResponse> getScore(ScoreRequest request) {
public ResponseEntity<ScoreResponse> getScore(ScoreRequest request, Boolean shapleyResults) {
try {
log.info("Got scoring request");
return ResponseEntity.ok(scorer.score(request));
Expand Down
26 changes: 26 additions & 0 deletions common/swagger/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ paths:
produces:
- application/json
parameters:
- in: query
Rajimut marked this conversation as resolved.
Show resolved Hide resolved
name: shapley_results
type: boolean
required: false
description: The boolean to say if Shap values are needed.
- name: payload
in: body
required: true
Expand Down Expand Up @@ -203,6 +208,27 @@ definitions:
type: array
items:
$ref: '#/definitions/Row'
inputShapleyContributions:
Rajimut marked this conversation as resolved.
Show resolved Hide resolved
# A single ShapleyResponse object: reference the schema directly. `items` only applies to
# `type: array`, so pairing it with `type: object` leaves this property effectively untyped.
$ref: '#/definitions/ShapleyResponse'
ShapleyResponse:
description: >
An Object that contains list of columns along with their shap values
type: object
properties:
fields:
description: >
An array holding the names of fields in the order of appearance in the rows of the `contributions` property.
type: array
items:
type: string
contributions:
description: >
An array of rows consisting of the shapley contributions output corresponding to columns in the fields
type: array
items:
$ref: '#/definitions/Row'
DataField:
type: object
properties:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ai.h2o.mojos.deploy.common.rest.model.Row;
import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest;
import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse;
import ai.h2o.mojos.deploy.common.rest.model.ShapleyResponse;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import com.google.common.base.Strings;
import java.util.ArrayList;
Expand Down Expand Up @@ -46,6 +47,24 @@ public ScoreResponse apply(MojoFrame mojoFrame, ScoreRequest scoreRequest) {
return response;
}

/**
 * Builds the API-level {@link ShapleyResponse} from a {@link MojoFrame} of shap values.
 *
 * <p>One empty {@link Row} is created per frame row and handed, together with the frame, to
 * {@code copyResultFields}. The frame's column names are copied into the response's
 * {@code fields} list.
 *
 * @param shapleyMojoFrame frame holding the shap values to convert
 * @return response carrying the contribution rows and the matching field names
 */
public ShapleyResponse getShapleyResponse(MojoFrame shapleyMojoFrame) {
  // Allocate exactly as many rows as the frame reports before copying values across.
  List<Row> contributionRows =
      Stream.generate(Row::new)
          .limit(shapleyMojoFrame.getNrows())
          .collect(Collectors.toList());
  copyResultFields(shapleyMojoFrame, contributionRows);

  // Copy the names into a fresh list rather than exposing the fixed-size array view.
  List<String> fieldNames = new ArrayList<>(asList(shapleyMojoFrame.getColumnNames()));

  ShapleyResponse shapleyResponse = new ShapleyResponse();
  shapleyResponse.setContributions(contributionRows);
  shapleyResponse.setFields(fieldNames);
  return shapleyResponse;
}

private static void copyFilteredInputFields(
ScoreRequest scoreRequest, Set<String> includedFields, List<Row> outputRows) {
if (includedFields.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ai.h2o.mojos.deploy.common.rest.model.Model;
import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest;
import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse;
import ai.h2o.mojos.deploy.common.rest.model.ShapleyResponse;
import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.lic.LicenseException;
Expand Down Expand Up @@ -32,6 +33,12 @@ public class MojoScorer {
private static final String MOJO_PIPELINE_PATH = System.getProperty(MOJO_PIPELINE_PATH_PROPERTY);
private static final MojoPipeline pipeline = loadMojoPipelineFromFile();

// Note: the mojo pipeline needs to be loaded a second time here because of a constraint in the
// Java mojo runtime: SHAP values and predictions cannot both be provided by the same pipeline.
// Link: https://github.com/h2oai/mojo2/blob/7a1ab76b09f056334842a5b442ff89859aabf518/doc/shap.md
private static final MojoPipeline pipelineShapley = loadMojoPipelineFromFile();
Rajimut marked this conversation as resolved.
Show resolved Hide resolved


private final RequestToMojoFrameConverter requestConverter;
private final MojoFrameToResponseConverter responseConverter;
private final MojoPipelineToModelInfoConverter modelInfoConverter;
Expand All @@ -54,6 +61,7 @@ public MojoScorer(
this.responseConverter = responseConverter;
this.modelInfoConverter = modelInfoConverter;
this.csvConverter = csvConverter;
pipelineShapley.setShapPredictContrib(true);
}

/**
Expand All @@ -63,13 +71,42 @@ public MojoScorer(
* @return response {@link ScoreResponse}
*/
public ScoreResponse score(ScoreRequest request) {
  // Plain scoring: delegate with the Shapley flag switched off.
  final boolean withShapleyContributions = false;
  return getScoreResponse(request, withShapleyContributions);
}

/**
* Scores an incoming request of type {@link ScoreRequest} and, when requested, attaches the
* shapley contributions to the response.
*
* <p>The response id is set to the UUID of the prediction pipeline.
*
* @param request {@link ScoreRequest}
* @param shapleyResults contributions are attached only when this equals {@code Boolean.TRUE};
*     both {@code null} and {@code false} leave them out
* @return response {@link ScoreResponse}
*/
public ScoreResponse getScoreResponse(ScoreRequest request, Boolean shapleyResults) {
MojoFrame requestFrame = requestConverter.apply(request, pipeline.getInputFrameBuilder());
// NOTE(review): the next two lines both declare responseFrame. They look like the before/after
// sides of one diff line (one-arg vs. two-arg doScore); only one can remain in compiled code.
MojoFrame responseFrame = doScore(requestFrame);
MojoFrame responseFrame = doScore(requestFrame, false);
ScoreResponse response = responseConverter.apply(responseFrame, request);
response.id(pipeline.getUuid());

// Set shapley contributions if requested. Boolean.TRUE.equals(...) treats a null flag as
// "not requested" instead of throwing on unboxing.
if (Boolean.TRUE.equals(shapleyResults)) {
response.setInputShapleyContributions(getShapleyResponse(request));
}
return response;
}

/**
 * Computes the shapley values for an incoming {@link ScoreRequest}.
 *
 * @param request {@link ScoreRequest} to compute contributions for
 * @return {@link ShapleyResponse} produced by the response converter from the scored frame
 */
private ShapleyResponse getShapleyResponse(ScoreRequest request) {
  // The input frame is built with the Shapley pipeline's own frame builder, not the
  // prediction pipeline's.
  MojoFrame shapleyInput = requestConverter.apply(request, pipelineShapley.getInputFrameBuilder());
  return responseConverter.getShapleyResponse(doScore(shapleyInput, true));
}

/**
* Method to score a csv file on path provided as part of request {@link ScoreRequest} payload.
*
Expand All @@ -82,7 +119,7 @@ public ScoreResponse scoreCsv(String csvFilePath) throws IOException {
try (InputStream csvStream = getInputStream(csvFilePath)) {
requestFrame = csvConverter.apply(csvStream, pipeline.getInputFrameBuilder());
}
MojoFrame responseFrame = doScore(requestFrame);
MojoFrame responseFrame = doScore(requestFrame, false);
ScoreResponse response = responseConverter.apply(responseFrame, new ScoreRequest());
response.id(pipeline.getUuid());
return response;
Expand All @@ -97,18 +134,23 @@ private static InputStream getInputStream(String filePath) throws IOException {
return new FileInputStream(filePath);
}

/**
* Transforms the request frame with one of the two static pipelines and logs the row count,
* column count and column names of both the input and the response at debug level.
*
* <p>NOTE(review): this block appears to interleave the before/after sides of a diff. It has two
* signatures, duplicated log arguments and two declarations of responseFrame. The comments
* below describe the two-argument variant.
*
* @param requestFrame frame to transform
* @param isShapleyContribution when {@code true} the frame goes through {@code pipelineShapley},
*     otherwise through {@code pipeline}
* @return the frame returned by the selected pipeline's {@code transform}
*/
private static MojoFrame doScore(MojoFrame requestFrame) {
private static MojoFrame doScore(MojoFrame requestFrame, boolean isShapleyContribution) {
log.debug(
"Input has {} rows, {} columns: {}",
requestFrame.getNrows(),
requestFrame.getNcols(),
Arrays.toString(requestFrame.getColumnNames()));
MojoFrame responseFrame = pipeline.transform(requestFrame);
"Input has {} rows, {} columns: {}",
requestFrame.getNrows(),
requestFrame.getNcols(),
Arrays.toString(requestFrame.getColumnNames()));
// Pick the pipeline by flag: shapley contributions and predictions use separate instances.
MojoFrame responseFrame;
if (isShapleyContribution) {
responseFrame = pipelineShapley.transform(requestFrame);
} else {
responseFrame = pipeline.transform(requestFrame);
}
// Mirror the input log line for the response frame.
log.debug(
"Response has {} rows, {} columns: {}",
responseFrame.getNrows(),
responseFrame.getNcols(),
Arrays.toString(responseFrame.getColumnNames()));
"Response has {} rows, {} columns: {}",
responseFrame.getNrows(),
responseFrame.getNcols(),
Arrays.toString(responseFrame.getColumnNames()));
return responseFrame;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ public ResponseEntity<String> getModelId() {
}

@Override
public ResponseEntity<ScoreResponse> getScore(ScoreRequest request) {
public ResponseEntity<ScoreResponse> getScore(ScoreRequest request, Boolean shapleyResults) {
try {
log.info("Got scoring request");
return ResponseEntity.ok(scorer.score(request));
ScoreResponse scoreResponse = scorer.getScoreResponse(request, shapleyResults);
return ResponseEntity.ok(scoreResponse);
} catch (Exception e) {
log.info("Failed scoring request: {}, due to: {}", request, e.getMessage());
log.debug(" - failure cause: ", e);
Expand Down