diff --git a/src/jdk.crypto.ec/share/classes/sun/security/ec/ECDHKeyAgreement.java b/src/jdk.crypto.ec/share/classes/sun/security/ec/ECDHKeyAgreement.java
index 6c79228ef22..de98167be27 100644
--- a/src/jdk.crypto.ec/share/classes/sun/security/ec/ECDHKeyAgreement.java
+++ b/src/jdk.crypto.ec/share/classes/sun/security/ec/ECDHKeyAgreement.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2009, 2020, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2009, 2021, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -25,21 +25,33 @@
 
 package sun.security.ec;
 
-import java.math.*;
-import java.security.*;
-import java.security.interfaces.*;
-import java.security.spec.*;
-import java.util.Optional;
-
-import javax.crypto.*;
-import javax.crypto.spec.*;
-
+import sun.security.ec.point.AffinePoint;
+import sun.security.ec.point.Point;
 import sun.security.util.ArrayUtil;
 import sun.security.util.CurveDB;
-import sun.security.util.ECUtil;
 import sun.security.util.NamedCurve;
-import sun.security.util.math.*;
-import sun.security.ec.point.*;
+import sun.security.util.math.ImmutableIntegerModuloP;
+import sun.security.util.math.IntegerFieldModuloP;
+import sun.security.util.math.MutableIntegerModuloP;
+import sun.security.util.math.SmallValue;
+
+import javax.crypto.KeyAgreementSpi;
+import javax.crypto.SecretKey;
+import javax.crypto.ShortBufferException;
+import javax.crypto.spec.SecretKeySpec;
+import java.math.BigInteger;
+import java.security.InvalidAlgorithmParameterException;
+import java.security.InvalidKeyException;
+import java.security.Key;
+import java.security.NoSuchAlgorithmException;
+import java.security.PrivateKey;
+import java.security.SecureRandom;
+import java.security.interfaces.ECPrivateKey;
+import java.security.interfaces.ECPublicKey;
+import java.security.spec.AlgorithmParameterSpec;
+import java.security.spec.ECParameterSpec;
+import java.security.spec.EllipticCurve;
+import java.util.Optional;
 
 /**
  * KeyAgreement implementation for ECDH.
@@ -50,6 +62,7 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
 
     // private key, if initialized
     private ECPrivateKey privateKey;
+    ECOperations privateKeyOps;
 
     // public key, non-null between doPhase() & generateSecret() only
     private ECPublicKey publicKey;
@@ -63,16 +76,34 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
     public ECDHKeyAgreement() {
     }
 
+    // Generic init
+    private void init(Key key) throws
+        InvalidKeyException, InvalidAlgorithmParameterException {
+        if (!(key instanceof PrivateKey)) {
+            throw new InvalidKeyException("Key must be instance of PrivateKey");
+        }
+        privateKey = (ECPrivateKey)ECKeyFactory.toECKey(key);
+        publicKey = null;
+        Optional<ECOperations> opsOpt =
+            ECOperations.forParameters(privateKey.getParams());
+        if (opsOpt.isEmpty()) {
+            NamedCurve nc = CurveDB.lookup(privateKey.getParams());
+            throw new InvalidAlgorithmParameterException(
+                "Curve not supported: " + (nc != null ? nc.toString() :
+                    "unknown"));
+        }
+        privateKeyOps = opsOpt.get();
+    }
+
     // see JCE spec
     @Override
     protected void engineInit(Key key, SecureRandom random)
             throws InvalidKeyException {
-        if (!(key instanceof PrivateKey)) {
-            throw new InvalidKeyException
-                        ("Key must be instance of PrivateKey");
+        try {
+            init(key);
+        } catch (InvalidAlgorithmParameterException e) {
+            throw new InvalidKeyException(e);
         }
-        privateKey = (ECPrivateKey) ECKeyFactory.toECKey(key);
-        publicKey = null;
     }
 
     // see JCE spec
@@ -84,7 +115,7 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
             throw new InvalidAlgorithmParameterException
                         ("Parameters not supported");
         }
-        engineInit(key, random);
+        init(key);
     }
 
     // see JCE spec
@@ -108,28 +139,34 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
 
         this.publicKey = (ECPublicKey) key;
 
-        ECParameterSpec params = publicKey.getParams();
-        int keyLenBits = params.getCurve().getField().getFieldSize();
+        int keyLenBits =
+            publicKey.getParams().getCurve().getField().getFieldSize();
         secretLen = (keyLenBits + 7) >> 3;
 
+        // Validate public key
+        validate(privateKeyOps, publicKey);
+
         return null;
     }
 
-    private static void validateCoordinate(BigInteger c, BigInteger mod) {
+    private static void validateCoordinate(BigInteger c, BigInteger mod)
+        throws InvalidKeyException{
         if (c.compareTo(BigInteger.ZERO) < 0) {
-            throw new ProviderException("invalid coordinate");
+            throw new InvalidKeyException("Invalid coordinate");
         }
 
         if (c.compareTo(mod) >= 0) {
-            throw new ProviderException("invalid coordinate");
+            throw new InvalidKeyException("Invalid coordinate");
         }
     }
 
     /*
-     * Check whether a public key is valid. Throw ProviderException
-     * if it is not valid or could not be validated.
+     * Check whether a public key is valid.
      */
-    private static void validate(ECOperations ops, ECPublicKey key) {
+    private static void validate(ECOperations ops, ECPublicKey key)
+        throws InvalidKeyException {
+
+        ECParameterSpec spec = key.getParams();
 
         // ensure that integers are in proper range
         BigInteger x = key.getW().getAffineX();
@@ -140,23 +177,23 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
         validateCoordinate(y, p);
 
         // ensure the point is on the curve
-        EllipticCurve curve = key.getParams().getCurve();
+        EllipticCurve curve = spec.getCurve();
         BigInteger rhs = x.modPow(BigInteger.valueOf(3), p).add(curve.getA()
             .multiply(x)).add(curve.getB()).mod(p);
         BigInteger lhs = y.modPow(BigInteger.valueOf(2), p).mod(p);
         if (!rhs.equals(lhs)) {
-            throw new ProviderException("point is not on curve");
+            throw new InvalidKeyException("Point is not on curve");
         }
 
         // check the order of the point
         ImmutableIntegerModuloP xElem = ops.getField().getElement(x);
         ImmutableIntegerModuloP yElem = ops.getField().getElement(y);
         AffinePoint affP = new AffinePoint(xElem, yElem);
-        byte[] order = key.getParams().getOrder().toByteArray();
+        byte[] order = spec.getOrder().toByteArray();
         ArrayUtil.reverse(order);
         Point product = ops.multiply(affP, order);
         if (!ops.isNeutral(product)) {
-            throw new ProviderException("point has incorrect order");
+            throw new InvalidKeyException("Point has incorrect order");
         }
 
     }
@@ -167,15 +204,13 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
         if ((privateKey == null) || (publicKey == null)) {
             throw new IllegalStateException("Not initialized correctly");
         }
+
         byte[] result;
-        Optional<byte[]> resultOpt = deriveKeyImpl(privateKey, publicKey);
-        if (resultOpt.isEmpty()) {
-            NamedCurve nc = CurveDB.lookup(publicKey.getParams());
-            throw new IllegalStateException(
-                new InvalidAlgorithmParameterException("Curve not supported: " +
-                    (nc != null ? nc.toString() : "unknown")));
+        try {
+            result = deriveKeyImpl(privateKey, privateKeyOps, publicKey);
+        } catch (Exception e) {
+            throw new IllegalStateException(e);
         }
-        result = resultOpt.get();
         publicKey = null;
         return result;
     }
@@ -210,48 +245,30 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
     }
 
     private static
-    Optional<byte[]> deriveKeyImpl(ECPrivateKey priv, ECPublicKey pubKey) {
-
-        ECParameterSpec ecSpec = priv.getParams();
-        EllipticCurve curve = ecSpec.getCurve();
-        Optional<ECOperations> opsOpt = ECOperations.forParameters(ecSpec);
-        if (opsOpt.isEmpty()) {
-            return Optional.empty();
-        }
-        ECOperations ops = opsOpt.get();
-        if (! (priv instanceof ECPrivateKeyImpl)) {
-            return Optional.empty();
-        }
-        ECPrivateKeyImpl privImpl = (ECPrivateKeyImpl) priv;
-        byte[] sArr = privImpl.getArrayS();
-
-        // to match the native implementation, validate the public key here
-        // and throw ProviderException if it is invalid
-        validate(ops, pubKey);
+    byte[] deriveKeyImpl(ECPrivateKey priv, ECOperations ops,
+        ECPublicKey pubKey) throws InvalidKeyException {
 
         IntegerFieldModuloP field = ops.getField();
         // convert s array into field element and multiply by the cofactor
-        MutableIntegerModuloP scalar = field.getElement(sArr).mutable();
+        MutableIntegerModuloP scalar = field.getElement(priv.getS()).mutable();
         SmallValue cofactor =
             field.getSmallValue(priv.getParams().getCofactor());
         scalar.setProduct(cofactor);
-        int keySize = (curve.getField().getFieldSize() + 7) / 8;
-        byte[] privArr = scalar.asByteArray(keySize);
-
+        int keySize =
+            (priv.getParams().getCurve().getField().getFieldSize() + 7) / 8;
         ImmutableIntegerModuloP x =
             field.getElement(pubKey.getW().getAffineX());
         ImmutableIntegerModuloP y =
             field.getElement(pubKey.getW().getAffineY());
-        AffinePoint affPub = new AffinePoint(x, y);
-        Point product = ops.multiply(affPub, privArr);
+        Point product = ops.multiply(new AffinePoint(x, y),
+            scalar.asByteArray(keySize));
         if (ops.isNeutral(product)) {
-            throw new ProviderException("Product is zero");
+            throw new InvalidKeyException("Product is zero");
         }
-        AffinePoint affProduct = product.asAffine();
 
-        byte[] result = affProduct.getX().asByteArray(keySize);
+        byte[] result = product.asAffine().getX().asByteArray(keySize);
         ArrayUtil.reverse(result);
 
-        return Optional.of(result);
+        return result;
     }
 }
diff --git a/test/jdk/com/sun/crypto/provider/KeyAgreement/ECKeyCheck.java b/test/jdk/com/sun/crypto/provider/KeyAgreement/ECKeyCheck.java
new file mode 100644
index 00000000000..e17d32c6bfd
--- /dev/null
+++ b/test/jdk/com/sun/crypto/provider/KeyAgreement/ECKeyCheck.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @bug 8261502
+ * @summary Check that ECPrivateKey's that are not ECPrivateKeyImpl can use
+ * ECDHKeyAgreement
+ */
+
+import javax.crypto.KeyAgreement;
+import java.math.BigInteger;
+import java.security.KeyPairGenerator;
+import java.security.interfaces.ECPrivateKey;
+import java.security.interfaces.ECPublicKey;
+import java.security.spec.ECGenParameterSpec;
+import java.security.spec.ECParameterSpec;
+
+public class ECKeyCheck {
+
+    public static final void main(String args[]) throws Exception {
+        ECGenParameterSpec spec = new ECGenParameterSpec("secp256r1");
+        KeyPairGenerator kpg = KeyPairGenerator.getInstance("EC");
+        kpg.initialize(spec);
+
+        ECPrivateKey privKey = (ECPrivateKey) kpg.generateKeyPair().getPrivate();
+        ECPublicKey pubKey = (ECPublicKey) kpg.generateKeyPair().getPublic();
+        generateECDHSecret(privKey, pubKey);
+        generateECDHSecret(new newPrivateKeyImpl(privKey), pubKey);
+    }
+
+    private static byte[] generateECDHSecret(ECPrivateKey privKey,
+        ECPublicKey pubKey) throws Exception {
+        KeyAgreement ka = KeyAgreement.getInstance("ECDH");
+        ka.init(privKey);
+        ka.doPhase(pubKey, true);
+        return ka.generateSecret();
+    }
+
+    // Test ECPrivateKey class
+    private static class newPrivateKeyImpl implements ECPrivateKey {
+        private ECPrivateKey p;
+
+        newPrivateKeyImpl(ECPrivateKey p) {
+            this.p = p;
+        }
+
+        public BigInteger getS() {
+            return p.getS();
+        }
+
+        public byte[] getEncoded() {
+            return p.getEncoded();
+        }
+
+        public String getFormat() {
+            return p.getFormat();
+        }
+
+        public String getAlgorithm() {
+            return p.getAlgorithm();
+        }
+
+        public ECParameterSpec getParams() {
+            return p.getParams();
+        }
+    }
+}