diff --git a/src/main/scala/uk/org/floop/sparqlTestRunner/Run.scala b/src/main/scala/uk/org/floop/sparqlTestRunner/Run.scala index b15944d..66464ba 100644 --- a/src/main/scala/uk/org/floop/sparqlTestRunner/Run.scala +++ b/src/main/scala/uk/org/floop/sparqlTestRunner/Run.scala @@ -20,12 +20,15 @@ import java.net.URI import java.nio.charset.StandardCharsets import java.nio.file.{Files, Path} +import java.util +import org.apache.http.HttpHeaders import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials} import org.apache.http.client.protocol.HttpClientContext import org.apache.http.client.utils.URIUtils import org.apache.http.impl.auth.BasicScheme import org.apache.http.impl.client.{BasicAuthCache, BasicCredentialsProvider, HttpClients} +import org.apache.http.message.BasicHeader import org.apache.jena.query._ import org.apache.jena.rdf.model.ModelFactory import org.apache.jena.riot.RDFDataMgr @@ -36,13 +39,13 @@ case class Config(dir: File = new File("tests/sparql"), report: File = new File("reports/TESTS-sparql-test-runner.xml"), - ignoreFail: Boolean = false, endpoint: Option[URI] = None, auth: Option[String] = None, + ignoreFail: Boolean = false, endpoint: Option[URI] = None, auth: Option[Either[String, String]] = None, params: Map[String, String] = Map.empty, data: Seq[File] = Seq()) object Run extends App { val packageVersion: String = getClass.getPackage.getImplementationVersion val parser = new scopt.OptionParser[Config]("sparql-testrunner") { - head("sparql-testrunner", packageVersion) + head("sparql-test-runner", packageVersion) opt[File]('t', "testdir") optional() valueName "" action { (x, c) => c.copy(dir = x) } text "location of SPARQL queries to run, defaults to tests/sparql" @@ -56,8 +59,11 @@ c.copy(endpoint = Some(x)) } text "SPARQL endpoint to run the queries against" opt[String]('a', "auth") optional() valueName "" action { (x, c) => - c.copy(auth = Some(x)) + c.copy(auth = Some(Left(x))) } text "basic authentication username:password" + opt[String]('k', "token") optional() valueName "" action { (x, c) => + c.copy(auth = Some(Right(x))) + } text "oAuth token" opt[Map[String,String]]('p', name="param") optional() valueName "l=\"somelabel\"@en,n=" action { (x, c) => c.copy(params = x) } text "variables to replace in query" @@ -80,7 +86,7 @@ // Querying a remote endpoint; if authentication is required, need to set up pre-emptive auth, // see https://hc.apache.org/httpcomponents-client-ga/tutorial/html/authentication.html config.auth match { - case Some(userpass) => + case Some(Left(userpass)) => val target = URIUtils.extractHost(uri) // new HttpHost(uri.getHost, uri.getPort) val credsProvider = new BasicCredentialsProvider() credsProvider.setCredentials( @@ -93,6 +99,12 @@ context.setAuthCache(authCache) val client = HttpClients.custom.build() (query: Query) => QueryExecutionFactory.sparqlService(uri.toString, query, client, context) + case Some(Right(token)) => + val authHeader = new BasicHeader(HttpHeaders.AUTHORIZATION, "Bearer " + token) + val headers = new util.ArrayList[BasicHeader] + headers.add(authHeader) + val client = HttpClients.custom.setDefaultHeaders(headers).build() + (query: Query) => QueryExecutionFactory.sparqlService(uri.toString, query, client) case None => (query: Query) => QueryExecutionFactory.sparqlService(uri.toString, query) }