Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added shapley values to rest scorer endpoint #241

Merged
merged 21 commits into from
Aug 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest;
import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse;
import ai.h2o.mojos.deploy.common.transform.MojoFrameToResponseConverter;
import ai.h2o.mojos.deploy.common.transform.MojoFrameToScoreResponseConverter;
import ai.h2o.mojos.deploy.common.transform.RequestChecker;
import ai.h2o.mojos.deploy.common.transform.RequestToMojoFrameConverter;
import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder;
import ai.h2o.mojos.deploy.common.transform.ScoreRequestFormatException;
import ai.h2o.mojos.deploy.common.transform.ScoreRequestToMojoFrameConverter;
import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.lic.LicenseException;
Expand Down Expand Up @@ -35,8 +35,10 @@ public final class MojoScorer {
private static final Object pipelineLock = new Object();
private static MojoPipeline pipeline;

private final RequestToMojoFrameConverter requestConverter = new RequestToMojoFrameConverter();
private final MojoFrameToResponseConverter responseConverter = new MojoFrameToResponseConverter();
private final ScoreRequestToMojoFrameConverter requestConverter
= new ScoreRequestToMojoFrameConverter();
private final MojoFrameToScoreResponseConverter responseConverter
= new MojoFrameToScoreResponseConverter();
private final RequestChecker requestChecker = new RequestChecker(new SampleRequestBuilder());

/** Processes a single {@link ScoreRequest} in the given AWS Lambda {@link Context}. */
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
package ai.h2o.mojos.deploy.sagemaker.hosted.config;

import ai.h2o.mojos.deploy.common.transform.ContributionRequestToMojoFrameConverter;
import ai.h2o.mojos.deploy.common.transform.CsvToMojoFrameConverter;
import ai.h2o.mojos.deploy.common.transform.MojoFrameToResponseConverter;
import ai.h2o.mojos.deploy.common.transform.MojoFrameToContributionResponseConverter;
import ai.h2o.mojos.deploy.common.transform.MojoFrameToScoreResponseConverter;
import ai.h2o.mojos.deploy.common.transform.MojoPipelineToModelInfoConverter;
import ai.h2o.mojos.deploy.common.transform.MojoScorer;
import ai.h2o.mojos.deploy.common.transform.RequestToMojoFrameConverter;
import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder;
import ai.h2o.mojos.deploy.common.transform.ScoreRequestToMojoFrameConverter;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
class ScorerConfiguration {
@Bean
public MojoFrameToResponseConverter responseConverter() {
return new MojoFrameToResponseConverter();
public MojoFrameToScoreResponseConverter responseConverter() {
return new MojoFrameToScoreResponseConverter();
}

@Bean
public RequestToMojoFrameConverter requestConverter() {
return new RequestToMojoFrameConverter();
public ScoreRequestToMojoFrameConverter requestConverter() {
return new ScoreRequestToMojoFrameConverter();
}

@Bean
public ContributionRequestToMojoFrameConverter contributionRequestConverter() {
return new ContributionRequestToMojoFrameConverter();
}

@Bean
public MojoFrameToContributionResponseConverter contributionResponseConverter() {
return new MojoFrameToContributionResponseConverter();
}

@Bean
Expand All @@ -38,10 +50,18 @@ public SampleRequestBuilder sampleRequestBuilder() {

@Bean
public MojoScorer mojoScorer(
RequestToMojoFrameConverter requestConverter,
MojoFrameToResponseConverter responseConverter,
ScoreRequestToMojoFrameConverter requestConverter,
MojoFrameToScoreResponseConverter responseConverter,
ContributionRequestToMojoFrameConverter contributionRequestConverter,
MojoFrameToContributionResponseConverter contributionResponseConverter,
MojoPipelineToModelInfoConverter modelInfoConverter,
CsvToMojoFrameConverter csvConverter) {
return new MojoScorer(requestConverter, responseConverter, modelInfoConverter, csvConverter);
return new MojoScorer(
requestConverter,
responseConverter,
contributionRequestConverter,
contributionResponseConverter,
modelInfoConverter,
csvConverter);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ai.h2o.mojos.deploy.sagemaker.hosted.controller;

import ai.h2o.mojos.deploy.common.rest.api.ModelApi;
import ai.h2o.mojos.deploy.common.rest.model.ContributionRequest;
import ai.h2o.mojos.deploy.common.rest.model.ContributionResponse;
import ai.h2o.mojos.deploy.common.rest.model.Model;
import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest;
import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse;
Expand All @@ -11,13 +13,16 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;

@Controller
public class ModelsApiController implements ModelApi {

private static final String UNIMPLEMENTED_MESSAGE
= "Shapley values are not implemented yet";
private static final Logger log = LoggerFactory.getLogger(ModelsApiController.class);

private final MojoScorer scorer;
Expand Down Expand Up @@ -84,6 +89,14 @@ public ResponseEntity<ScoreResponse> getScoreByFile(String file) {
}
}

@Override
public ResponseEntity<ContributionResponse> getContribution(
ContributionRequest request) {
// TODO: to be implemented in the future
log.info(" Unsupported operation: " + UNIMPLEMENTED_MESSAGE);
return ResponseEntity.status(HttpStatus.NOT_IMPLEMENTED).build();
}

@Override
public ResponseEntity<ScoreRequest> getSampleRequest() {
return ResponseEntity.ok(sampleRequestBuilder.build(scorer.getPipeline().getInputMeta()));
Expand Down
106 changes: 101 additions & 5 deletions common/swagger/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ info:
description: >-
This is a definition of the REST API for scoring from H2O. This API is
intended to be used within DAI and eventually across all H2O scoring systems
version: 1.0.0
version: 1.1.0
title: Scoring API - v1
termsOfService: ''
contact:
Expand Down Expand Up @@ -107,6 +107,32 @@ paths:
$ref: '#/definitions/ScoreResponse'
'400':
description: Invalid payload
'/model/contribution':
post:
tags:
- contribution
summary: Contribution score or Shapley values on given rows
description: Computes contribution score with the rows sent in the body of the post request
operationId: getContribution
consumes:
- application/json
produces:
- application/json
parameters:
- name: payload
in: body
required: true
schema:
$ref: '#/definitions/ContributionRequest'
responses:
'200':
description: Successful operation
schema:
$ref: '#/definitions/ContributionResponse'
'501':
description: Implementation not supported
'400':
description: Invalid payload
securityDefinitions:
api_key:
type: apiKey
Expand All @@ -127,10 +153,7 @@ definitions:
properties:
scoringType:
type: string
enum:
- REGRESSION
- CLASSIFICATION
- BINOMIAL
$ref: '#/definitions/ScoringType'
scoringResponLabels:
type: array
items:
Expand All @@ -141,9 +164,69 @@ definitions:
type: array
items:
type: string

ContributionRequest:
type: object
required:
- requestShapleyValueType
properties:
requestShapleyValueType:
description: >
The string to say what type of Shap values are needed.
`ORIGINAL` implies Shap values of original features are requested,
Rajimut marked this conversation as resolved.
Show resolved Hide resolved
`TRANSFORMED` implies that Shap values of transformed features are requested.
$ref: '#/definitions/ShapleyType'
fields:
description: >
An array holding the names of fields in the order of appearance in the `rows` property. The length of `fields`
has to match length of each row in `rows`. No duplicates are allowed.
type: array
items:
type: string
rows:
description: >
An array of rows consisting the actual input data for scoring, one scoring request per row.
type: array
items:
$ref: '#/definitions/Row'
ContributionResponse:
type: object
properties:
features:
description: >
An array holding the names of fields in the order of appearance in the rows of the `contributions` property.
type: array
items:
type: string
contributionGroups:
description: >
An array of rows consisting of the shapley contributions output corresponding to an output group.
type: array
items:
$ref: '#/definitions/ContributionGroup'
ContributionGroup:
type: object
properties:
outputGroup:
description: >
Name of the output group. It will be populated only for multinomial models.
Shapley values are not supported for third party models yet, hence this field will not be populated.
type: string
contributions:
description: >
An array of rows consisting of the shapley contributions output corresponding to columns in the fields
type: array
items:
$ref: '#/definitions/Row'
ScoreRequest:
type: object
properties:
requestShapleyValueType:
description: >
The string to say what type of Shap values are needed.
`ORIGINAL` implies Shap values of original features are requested,
`TRANSFORMED` implies that Shap values of transformed features are requested.
$ref: '#/definitions/ShapleyType'
includeFieldsInOutput:
description: >
An array holding the list of field names to be copied from the input request row to the corresponding scoring
Expand Down Expand Up @@ -203,6 +286,13 @@ definitions:
type: array
items:
$ref: '#/definitions/Row'
featureShapleyContributions:
type: object
description: >
An object with features and shapley values that was requested by the client.
It is currently available for transformed features of binomial, regression and multinomial models of mojo2.
This field will not be populated if the Shapley values are not available for a model.
$ref: '#/definitions/ContributionResponse'
DataField:
type: object
properties:
Expand Down Expand Up @@ -235,3 +325,9 @@ definitions:
type: array
items:
$ref: '#/definitions/DataField'
ShapleyType:
type: string
enum: [ ORIGINAL, TRANSFORMED, NONE ]
ScoringType:
type: string
enum: [ REGRESSION, CLASSIFICATION, BINOMIAL ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package ai.h2o.mojos.deploy.common.transform;

import ai.h2o.mojos.deploy.common.rest.model.ContributionRequest;
import ai.h2o.mojos.deploy.common.rest.model.Row;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameBuilder;
import ai.h2o.mojos.runtime.frame.MojoRowBuilder;
import java.util.List;
import java.util.function.BiFunction;

/**
* Converts the original API request object
* {@link ContributionRequest} into the input {@link MojoFrame}.
*/
public class ContributionRequestToMojoFrameConverter
implements BiFunction<ContributionRequest, MojoFrameBuilder, MojoFrame> {
@Override
public MojoFrame apply(ContributionRequest scoreRequest, MojoFrameBuilder frameBuilder) {
List<String> fields = scoreRequest.getFields();
if (scoreRequest.getRows() != null) {
for (Row row : scoreRequest.getRows()) {
MojoRowBuilder rowBuilder = frameBuilder.getMojoRowBuilder();
for (int i = 0; i < row.size(); i++) {
rowBuilder.setValue(fields.get(i), row.get(i));
}
frameBuilder.addRow(rowBuilder);
}
}

return frameBuilder.toMojoFrame();
}
}
Loading