Skip to content

Commit

Permalink
feat: isClassDeclared prepareInvocations, fix cairo0 test (#1211)
Browse files Browse the repository at this point in the history
* feat: fix cairo0 test, feat provider.isClassDeclared, feat provider.createBulkInvocations

* chore: rename ContractIdentifier to ContractClassIdentifier

* chore: contractClassIdentifier fix

* chore: prepareInvocations, LibraryError
  • Loading branch information
tabaktoni authored Sep 3, 2024
1 parent e435903 commit 9fdf54f
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 46 deletions.
142 changes: 107 additions & 35 deletions __tests__/account.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,10 @@ describe('deploy and test Wallet', () => {
const calldata = { publicKey: pubKey };

// declare account
const declareAccount = await account.declare({
const declareAccount = await account.declareIfNot({
contract: compiledOpenZeppelinAccount,
});
const accountClassHash = declareAccount.class_hash;
await account.waitForTransaction(declareAccount.transaction_hash);

// fund new account
const tobeAccountAddress = hash.calculateContractAddressFromHash(
Expand Down Expand Up @@ -193,6 +192,9 @@ describe('deploy and test Wallet', () => {
});

describe('simulate transaction - single transaction S0.11.2', () => {
test('simulate empty invocations', async () => {
await expect(account.simulateTransaction([])).rejects.toThrow(TypeError);
});
test('simulate INVOKE Cairo 0', async () => {
const res = await account.simulateTransaction([
{
Expand Down Expand Up @@ -245,37 +247,30 @@ describe('deploy and test Wallet', () => {

describeIfDevnet('declare tests only on devnet', () => {
test('simulate DECLARE - Cairo 0 Contract', async () => {
const res = await account.simulateTransaction([
const invocation = await provider.prepareInvocations([
{
type: TransactionType.DECLARE,
contract: compiledErc20,
},
]);
expect(res).toMatchSchemaRef('SimulateTransactionResponse');
if (invocation.length) {
const res = await account.simulateTransaction(invocation);
expect(res).toMatchSchemaRef('SimulateTransactionResponse');
}
});
});

test('simulate DECLARE - Cairo 1 Contract - test if not already declared', async () => {
const declareContractPayload = extractContractHashes({
contract: compiledHelloSierra,
casm: compiledHelloSierraCasm,
});
let skip = false;
try {
await account.getClassByHash(declareContractPayload.classHash);
skip = true;
} catch (error) {
/* empty */
}
const invocation = await provider.prepareInvocations([
{
type: TransactionType.DECLARE,
contract: compiledHelloSierra,
casm: compiledHelloSierraCasm,
},
]);

if (!skip) {
const res = await account.simulateTransaction([
{
type: TransactionType.DECLARE,
contract: compiledHelloSierra,
casm: compiledHelloSierraCasm,
},
]);
if (invocation.length) {
const res = await account.simulateTransaction(invocation);
expect(res).toMatchSchemaRef('SimulateTransactionResponse');
}
});
Expand Down Expand Up @@ -625,6 +620,10 @@ describe('deploy and test Wallet', () => {
expect(result).toMatchSchemaRef('EstimateFee');
});

test('estimate fee bulk on empty invocations', async () => {
await expect(account.estimateFeeBulk([])).rejects.toThrow(TypeError);
});

test('estimate fee bulk invoke functions', async () => {
// TODO @dhruvkelawala check expectation for feeTransactionVersion
// const innerInvokeEstFeeSpy = jest.spyOn(account.signer, 'signTransaction');
Expand Down Expand Up @@ -696,22 +695,80 @@ describe('deploy and test Wallet', () => {
});

describeIfDevnet('declare tests only on devnet', () => {
test('declare, deploy & multi invoke functions', async () => {
const res = await account.estimateFeeBulk([
/* {
// Cairo 1.1.0, if declared estimate error with can't redeclare same contract
type: TransactionType.DECLARE,
contract: compiledHelloSierra,
casm: compiledHelloSierraCasm,
}, */
test('Manual: declare, deploy & multi invoke functions', async () => {
/*
* For Cairo0 and Cairo1 contracts re-declaration of the class throw an errors
* as soo We first need to test is class is already declared
*/
const isDeclaredCairo0 = await account.isClassDeclared({
classHash: '0x54328a1075b8820eb43caf0caa233923148c983742402dcfc38541dd843d01a',
});

const hashes = extractContractHashes({
contract: compiledHelloSierra,
casm: compiledHelloSierraCasm,
});

const isDeclaredCairo1 = await account.isClassDeclared({ classHash: hashes.classHash });

const invocations = [
{
// Cairo 0
type: TransactionType.DECLARE,
type: TransactionType.INVOKE,
payload: [
{
contractAddress: erc20Address,
entrypoint: 'approve',
calldata: {
address: erc20Address,
amount: uint256(10),
},
},
{
contractAddress: erc20Address,
entrypoint: 'transfer',
calldata: [erc20.address, '10', '0'],
},
],
},
{
type: TransactionType.DEPLOY,
payload: {
contract: compiledErc20,
classHash: '0x54328a1075b8820eb43caf0caa233923148c983742402dcfc38541dd843d01a',
constructorCalldata: ['Token', 'ERC20', account.address],
},
},
...(!isDeclaredCairo0
? [
{
// Cairo 0
type: TransactionType.DECLARE,
payload: {
contract: compiledErc20,
classHash: '0x54328a1075b8820eb43caf0caa233923148c983742402dcfc38541dd843d01a',
},
},
]
: []),
...(!isDeclaredCairo1
? [
{
// Cairo 1.1.0, if declared estimate error with can't redeclare same contract
type: TransactionType.DECLARE,
contract: compiledHelloSierra,
casm: compiledHelloSierraCasm,
},
]
: []),
];

const res = await account.estimateFeeBulk(invocations);
res.forEach((value) => {
expect(value).toMatchSchemaRef('EstimateFee');
});
});

test('prepareInvocations: unordered declare, deploy & multi invoke', async () => {
const invocations = await provider.prepareInvocations([
{
type: TransactionType.DEPLOY,
payload: {
Expand All @@ -737,8 +794,23 @@ describe('deploy and test Wallet', () => {
},
],
},
{
// Cairo 0
type: TransactionType.DECLARE,
payload: {
contract: compiledErc20,
classHash: '0x54328a1075b8820eb43caf0caa233923148c983742402dcfc38541dd843d01a',
},
},
{
// Cairo 1.1.0, if declared estimate error with can't redeclare same contract
type: TransactionType.DECLARE,
contract: compiledHelloSierra,
casm: compiledHelloSierraCasm,
},
]);
expect(res).toHaveLength(3);

const res = await account.estimateFeeBulk(invocations);
res.forEach((value) => {
expect(value).toMatchSchemaRef('EstimateFee');
});
Expand Down
13 changes: 8 additions & 5 deletions src/account/default.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// eslint-disable-next-line @typescript-eslint/no-unused-vars
import type { SPEC } from 'starknet-types-07';

import {
OutsideExecutionCallerAny,
SNIP9_V1_INTERFACE_ID,
Expand Down Expand Up @@ -53,16 +54,16 @@ import {
type OutsideExecutionOptions,
type OutsideTransaction,
} from '../types/outsideExecution';
import {
buildExecuteFromOutsideCallData,
getOutsideCall,
getTypedData,
} from '../utils/outsideExecution';
import { CallData } from '../utils/calldata';
import { extractContractHashes, isSierra } from '../utils/contract';
import { parseUDCEvent } from '../utils/events';
import { calculateContractAddressFromHash } from '../utils/hash';
import { isHex, toBigInt, toCairoBool, toHex } from '../utils/num';
import {
buildExecuteFromOutsideCallData,
getOutsideCall,
getTypedData,
} from '../utils/outsideExecution';
import { parseContract } from '../utils/provider';
import { isString } from '../utils/shortString';
import { supportsInterface } from '../utils/src5';
Expand Down Expand Up @@ -279,6 +280,7 @@ export class Account extends Provider implements AccountInterface {
invocations: Invocations,
details: UniversalDetails = {}
): Promise<EstimateFeeBulk> {
if (!invocations.length) throw TypeError('Invocations should be non-empty array');
const { nonce, blockIdentifier, version, skipValidate } = details;
const accountInvocations = await this.accountInvocationsFactory(invocations, {
...v3Details(details),
Expand All @@ -304,6 +306,7 @@ export class Account extends Provider implements AccountInterface {
invocations: Invocations,
details: SimulateTransactionDetails = {}
): Promise<SimulateTransactionResponse> {
if (!invocations.length) throw TypeError('Invocations should be non-empty array');
const { nonce, blockIdentifier, skipValidate = true, skipExecute, version } = details;
const accountInvocations = await this.accountInvocationsFactory(invocations, {
...v3Details(details),
Expand Down
73 changes: 67 additions & 6 deletions src/provider/rpc.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { SPEC } from 'starknet-types-07';

import { RPC06, RPC07, RpcChannel } from '../channel';
import {
AccountInvocations,
Expand All @@ -8,12 +9,14 @@ import {
BlockTag,
Call,
ContractClassResponse,
ContractClassIdentifier,
ContractVersion,
DeclareContractTransaction,
DeployAccountContractTransaction,
GetBlockResponse,
GetTxReceiptResponseWithoutHelper,
Invocation,
Invocations,
InvocationsDetailsWithNonce,
PendingBlock,
PendingStateUpdate,
Expand All @@ -25,24 +28,24 @@ import {
getContractVersionOptions,
getEstimateFeeBulkOptions,
getSimulateTransactionOptions,
waitForTransactionOptions,
type Signature,
type TypedData,
waitForTransactionOptions,
} from '../types';
import type { TransactionWithHash } from '../types/provider/spec';
import assert from '../utils/assert';
import { CallData } from '../utils/calldata';
import { getAbiContractVersion } from '../utils/calldata/cairo';
import { isSierra } from '../utils/contract';
import { extractContractHashes, isSierra } from '../utils/contract';
import { solidityUint256PackedKeccak256 } from '../utils/hash';
import { isBigNumberish, toBigInt, toHex } from '../utils/num';
import { wait } from '../utils/provider';
import { RPCResponseParser } from '../utils/responseParser/rpc';
import { formatSignature } from '../utils/stark';
import { GetTransactionReceiptResponse, ReceiptTx } from '../utils/transactionReceipt';
import { getMessageHash, validateTypedData } from '../utils/typedData';
import { LibraryError } from './errors';
import { ProviderInterface } from './interface';
import { solidityUint256PackedKeccak256 } from '../utils/hash';
import { CallData } from '../utils/calldata';
import { formatSignature } from '../utils/stark';
import { getMessageHash, validateTypedData } from '../utils/typedData';

export class RpcProvider implements ProviderInterface {
public responseParser: RPCResponseParser;
Expand Down Expand Up @@ -570,4 +573,62 @@ export class RpcProvider implements ProviderInterface {

throw Error(`Signature verification Error: ${error}`);
}

/**
* Test if class is already declared from ContractClassIdentifier
* Helper method using getClass
* @param ContractClassIdentifier
* @param blockIdentifier
*/
public async isClassDeclared(
contractClassIdentifier: ContractClassIdentifier,
blockIdentifier?: BlockIdentifier
) {
let classHash: string;
if (!contractClassIdentifier.classHash && 'contract' in contractClassIdentifier) {
const hashes = extractContractHashes(contractClassIdentifier);
classHash = hashes.classHash;
} else if (contractClassIdentifier.classHash) {
classHash = contractClassIdentifier.classHash;
} else {
throw Error('contractClassIdentifier type not satisfied');
}

try {
const result = await this.getClass(classHash, blockIdentifier);
return result instanceof Object;
} catch (error) {
if (error instanceof LibraryError) {
return false;
}
throw error;
}
}

/**
* Build bulk invocations with auto-detect declared class
* 1. Test if class is declared if not declare it preventing already declared class error and not declared class errors
* 2. Order declarations first
* @param invocations
*/
public async prepareInvocations(invocations: Invocations) {
const bulk: Invocations = [];
// Build new ordered array
// eslint-disable-next-line no-restricted-syntax
for (const invocation of invocations) {
if (invocation.type === TransactionType.DECLARE) {
// Test if already declared
// eslint-disable-next-line no-await-in-loop
const isDeclared = await this.isClassDeclared(
'payload' in invocation ? invocation.payload : invocation
);
if (!isDeclared) {
bulk.unshift(invocation);
}
} else {
bulk.push(invocation);
}
}
return bulk;
}
}
6 changes: 6 additions & 0 deletions src/types/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { StarknetChainId } from '../../constants';
import { weierstrass } from '../../utils/ec';
import { EDataAvailabilityMode, ResourceBounds } from '../api';
import { CairoEnum } from '../cairoEnum';
import { ValuesType } from '../helpers/valuesType';
import { CompiledContract, CompiledSierraCasm, ContractClass } from './contract';

export type WeierstrassSignatureType = weierstrass.SignatureType;
Expand Down Expand Up @@ -100,6 +101,11 @@ export type DeclareContractPayload = {
compiledClassHash?: string;
};

/**
* DeclareContractPayload with classHash or contract defined
*/
export type ContractClassIdentifier = DeclareContractPayload | { classHash: string };

export type CompleteDeclareContractPayload = {
contract: CompiledContract | string;
classHash: string;
Expand Down

0 comments on commit 9fdf54f

Please sign in to comment.